From 3ff81dc9c9ce20efd3bf54cf52adaf438c483a92 Mon Sep 17 00:00:00 2001 From: Mykola Mokhnach Date: Tue, 7 Mar 2023 20:56:51 +0100 Subject: [PATCH] refactor: Extract zlib-related logic into a single module (#7223) ## What do these changes do? Addresses issue https://github.com/aio-libs/aiohttp/issues/7192 Refactors the logic to have the zlib-related stuff concentrated into a single module ## Are there changes in behavior for the user? No ## Related issue number https://github.com/aio-libs/aiohttp/issues/7192 ## Checklist - [x] I think the code is well written - [x] Unit tests for the changes exist - [ ] Documentation reflects the changes - [ ] If you provide code modification, please add yourself to `CONTRIBUTORS.txt` * The format is <Name> <Surname>. * Please keep alphabetical order, the file is sorted by names. - [ ] Add a new news fragment into the `CHANGES` folder * name it `<issue_id>.<type>` for example (588.bugfix) * if you don't have an `issue_id` change it to the pr id after creating the pr * ensure type is one of the following: * `.feature`: Signifying a new feature. * `.bugfix`: Signifying a bug fix. * `.doc`: Signifying a documentation improvement. * `.removal`: Signifying a deprecation or removal of public API. * `.misc`: A ticket has been closed, but it is not of interest to users. * Make sure to use full sentences with correct case and punctuation, for example: "Fix issue with non-ascii contents in doctest text files." 
--------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sam Bull --- aiohttp/client_reqrep.py | 2 +- aiohttp/compression_utils.py | 148 +++++++++++++++++++++++++++++++++++ aiohttp/http_parser.py | 44 +++-------- aiohttp/http_websocket.py | 16 ++-- aiohttp/http_writer.py | 12 +-- aiohttp/multipart.py | 28 ++++--- aiohttp/web_response.py | 41 ++++------ tests/test_http_parser.py | 4 +- 8 files changed, 208 insertions(+), 87 deletions(-) create mode 100644 aiohttp/compression_utils.py diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 5a705397cdf..dadc4e2fa7b 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -37,6 +37,7 @@ InvalidURL, ServerFingerprintMismatch, ) +from .compression_utils import HAS_BROTLI from .formdata import FormData from .hdrs import CONTENT_TYPE from .helpers import ( @@ -51,7 +52,6 @@ set_result, ) from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, StreamWriter -from .http_parser import HAS_BROTLI from .log import client_logger from .streams import StreamReader from .typedefs import ( diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py new file mode 100644 index 00000000000..8abc4fa7c3c --- /dev/null +++ b/aiohttp/compression_utils.py @@ -0,0 +1,148 @@ +import asyncio +import zlib +from concurrent.futures import Executor +from typing import Optional, cast + +try: + import brotli + + HAS_BROTLI = True +except ImportError: # pragma: no cover + HAS_BROTLI = False + +MAX_SYNC_CHUNK_SIZE = 1024 + + +def encoding_to_mode( + encoding: Optional[str] = None, + suppress_deflate_header: bool = False, +) -> int: + if encoding == "gzip": + return 16 + zlib.MAX_WBITS + + return -zlib.MAX_WBITS if suppress_deflate_header else zlib.MAX_WBITS + + +class ZlibBaseHandler: + def __init__( + self, + mode: int, + executor: Optional[Executor] = None, + max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, + ): + self._mode = mode 
+ self._executor = executor + self._max_sync_chunk_size = max_sync_chunk_size + + +class ZLibCompressor(ZlibBaseHandler): + def __init__( + self, + encoding: Optional[str] = None, + suppress_deflate_header: bool = False, + level: Optional[int] = None, + wbits: Optional[int] = None, + strategy: int = zlib.Z_DEFAULT_STRATEGY, + executor: Optional[Executor] = None, + max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, + ): + super().__init__( + mode=encoding_to_mode(encoding, suppress_deflate_header) + if wbits is None + else wbits, + executor=executor, + max_sync_chunk_size=max_sync_chunk_size, + ) + if level is None: + self._compressor = zlib.compressobj(wbits=self._mode, strategy=strategy) + else: + self._compressor = zlib.compressobj( + wbits=self._mode, strategy=strategy, level=level + ) + + def compress_sync(self, data: bytes) -> bytes: + return self._compressor.compress(data) + + async def compress(self, data: bytes) -> bytes: + if ( + self._max_sync_chunk_size is not None + and len(data) > self._max_sync_chunk_size + ): + return await asyncio.get_event_loop().run_in_executor( + self._executor, self.compress_sync, data + ) + return self.compress_sync(data) + + def flush(self, mode: int = zlib.Z_FINISH) -> bytes: + return self._compressor.flush(mode) + + +class ZLibDecompressor(ZlibBaseHandler): + def __init__( + self, + encoding: Optional[str] = None, + suppress_deflate_header: bool = False, + executor: Optional[Executor] = None, + max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, + ): + super().__init__( + mode=encoding_to_mode(encoding, suppress_deflate_header), + executor=executor, + max_sync_chunk_size=max_sync_chunk_size, + ) + self._decompressor = zlib.decompressobj(wbits=self._mode) + + def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes: + return self._decompressor.decompress(data, max_length) + + async def decompress(self, data: bytes, max_length: int = 0) -> bytes: + if ( + self._max_sync_chunk_size is not None + and 
len(data) > self._max_sync_chunk_size + ): + return await asyncio.get_event_loop().run_in_executor( + self._executor, self.decompress_sync, data, max_length + ) + return self.decompress_sync(data, max_length) + + def flush(self, length: int = 0) -> bytes: + return ( + self._decompressor.flush(length) + if length > 0 + else self._decompressor.flush() + ) + + @property + def eof(self) -> bool: + return self._decompressor.eof + + @property + def unconsumed_tail(self) -> bytes: + return self._decompressor.unconsumed_tail + + @property + def unused_data(self) -> bytes: + return self._decompressor.unused_data + + +class BrotliDecompressor: + # Supports both 'brotlipy' and 'Brotli' packages + # since they share an import name. The top branches + # are for 'brotlipy' and bottom branches for 'Brotli' + def __init__(self) -> None: + if not HAS_BROTLI: + raise RuntimeError( + "The brotli decompression is not available. " + "Please install `Brotli` module" + ) + self._obj = brotli.Decompressor() + + def decompress_sync(self, data: bytes) -> bytes: + if hasattr(self._obj, "decompress"): + return cast(bytes, self._obj.decompress(data)) + return cast(bytes, self._obj.process(data)) + + def flush(self) -> bytes: + if hasattr(self._obj, "flush"): + return cast(bytes, self._obj.flush()) + return b"" diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index b9b2488d2a6..f5de343326f 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -3,11 +3,9 @@ import collections import re import string -import zlib from contextlib import suppress from enum import IntEnum from typing import ( - Any, Generic, List, NamedTuple, @@ -18,7 +16,6 @@ Type, TypeVar, Union, - cast, ) from multidict import CIMultiDict, CIMultiDictProxy, istr @@ -27,6 +24,7 @@ from . 
import hdrs from .base_protocol import BaseProtocol +from .compression_utils import HAS_BROTLI, BrotliDecompressor, ZLibDecompressor from .helpers import NO_EXTENSIONS, BaseTimerContext from .http_exceptions import ( BadHttpMessage, @@ -42,14 +40,6 @@ from .streams import EMPTY_PAYLOAD, StreamReader from .typedefs import RawHeaders -try: - import brotli - - HAS_BROTLI = True -except ImportError: # pragma: no cover - HAS_BROTLI = False - - __all__ = ( "HeadersParser", "HttpParser", @@ -859,34 +849,16 @@ def __init__(self, out: StreamReader, encoding: Optional[str]) -> None: self.encoding = encoding self._started_decoding = False + self.decompressor: Union[BrotliDecompressor, ZLibDecompressor] if encoding == "br": if not HAS_BROTLI: # pragma: no cover raise ContentEncodingError( "Can not decode content-encoding: brotli (br). " "Please install `Brotli`" ) - - class BrotliDecoder: - # Supports both 'brotlipy' and 'Brotli' packages - # since they share an import name. The top branches - # are for 'brotlipy' and bottom branches for 'Brotli' - def __init__(self) -> None: - self._obj = brotli.Decompressor() - - def decompress(self, data: bytes) -> bytes: - if hasattr(self._obj, "decompress"): - return cast(bytes, self._obj.decompress(data)) - return cast(bytes, self._obj.process(data)) - - def flush(self) -> bytes: - if hasattr(self._obj, "flush"): - return cast(bytes, self._obj.flush()) - return b"" - - self.decompressor: Any = BrotliDecoder() + self.decompressor = BrotliDecompressor() else: - zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else zlib.MAX_WBITS - self.decompressor = zlib.decompressobj(wbits=zlib_mode) + self.decompressor = ZLibDecompressor(encoding=encoding) def set_exception(self, exc: BaseException) -> None: self.out.set_exception(exc) @@ -907,10 +879,12 @@ def feed_data(self, chunk: bytes, size: int) -> None: ): # Change the decoder to decompress incorrectly compressed data # Actually we should issue a warning about non-RFC-compliant data. 
- self.decompressor = zlib.decompressobj(wbits=-zlib.MAX_WBITS) + self.decompressor = ZLibDecompressor( + encoding=self.encoding, suppress_deflate_header=True + ) try: - chunk = self.decompressor.decompress(chunk) + chunk = self.decompressor.decompress_sync(chunk) except Exception: raise ContentEncodingError( "Can not decode content-encoding: %s" % self.encoding @@ -926,7 +900,7 @@ def feed_eof(self) -> None: if chunk or self.size > 0: self.out.feed_data(chunk, len(chunk)) - if self.encoding == "deflate" and not self.decompressor.eof: + if self.encoding == "deflate" and not self.decompressor.eof: # type: ignore raise ContentEncodingError("deflate") self.out.feed_eof() diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index fe5058cae62..deaa5c0f067 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -15,6 +15,7 @@ from typing_extensions import Final from .base_protocol import BaseProtocol +from .compression_utils import ZLibCompressor, ZLibDecompressor from .helpers import NO_EXTENSIONS from .streams import DataQueue @@ -270,7 +271,7 @@ def __init__( self._payload_length = 0 self._payload_length_flag = 0 self._compressed: Optional[bool] = None - self._decompressobj: Any = None # zlib.decompressobj actually + self._decompressobj: Optional[ZLibDecompressor] = None self._compress = compress def feed_eof(self) -> None: @@ -290,7 +291,7 @@ def feed_data(self, data: bytes) -> Tuple[bool, bytes]: def _feed_data(self, data: bytes) -> Tuple[bool, bytes]: for fin, opcode, payload, compressed in self.parse_frame(data): if compressed and not self._decompressobj: - self._decompressobj = zlib.decompressobj(wbits=-zlib.MAX_WBITS) + self._decompressobj = ZLibDecompressor(suppress_deflate_header=True) if opcode == WSMsgType.CLOSE: if len(payload) >= 2: close_code = UNPACK_CLOSE_CODE(payload[:2])[0] @@ -375,8 +376,9 @@ def _feed_data(self, data: bytes) -> Tuple[bool, bytes]: # Decompress process must to be done after all packets # received. 
if compressed: + assert self._decompressobj is not None self._partial.extend(_WS_DEFLATE_TRAILING) - payload_merged = self._decompressobj.decompress( + payload_merged = self._decompressobj.decompress_sync( self._partial, self._max_msg_size ) if self._decompressobj.unconsumed_tail: @@ -604,16 +606,16 @@ async def _send_frame( if (compress or self.compress) and opcode < 8: if compress: # Do not set self._compress if compressing is for this frame - compressobj = zlib.compressobj(level=zlib.Z_BEST_SPEED, wbits=-compress) + compressobj = ZLibCompressor(level=zlib.Z_BEST_SPEED, wbits=-compress) else: # self.compress if not self._compressobj: - self._compressobj = zlib.compressobj( + self._compressobj = ZLibCompressor( level=zlib.Z_BEST_SPEED, wbits=-self.compress ) compressobj = self._compressobj - message = compressobj.compress(message) - message = message + compressobj.flush( + message = await compressobj.compress(message) + message += compressobj.flush( zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH ) if message.endswith(_WS_DEFLATE_TRAILING): diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index 73f0f96f0ae..8f2d9086b92 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -8,6 +8,7 @@ from .abc import AbstractStreamWriter from .base_protocol import BaseProtocol +from .compression_utils import ZLibCompressor from .helpers import NO_EXTENSIONS __all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11") @@ -43,7 +44,7 @@ def __init__( self.output_size = 0 self._eof = False - self._compress: Any = None + self._compress: Optional[ZLibCompressor] = None self._drain_waiter = None self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent @@ -63,8 +64,7 @@ def enable_chunking(self) -> None: def enable_compression( self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY ) -> None: - zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else zlib.MAX_WBITS - self._compress = zlib.compressobj(wbits=zlib_mode, 
strategy=strategy) + self._compress = ZLibCompressor(encoding=encoding, strategy=strategy) def _write(self, chunk: bytes) -> None: size = len(chunk) @@ -93,7 +93,7 @@ async def write( chunk = chunk.cast("c") if self._compress is not None: - chunk = self._compress.compress(chunk) + chunk = await self._compress.compress(chunk) if not chunk: return @@ -138,9 +138,9 @@ async def write_eof(self, chunk: bytes = b"") -> None: if self._compress: if chunk: - chunk = self._compress.compress(chunk) + chunk = await self._compress.compress(chunk) - chunk = chunk + self._compress.flush() + chunk += self._compress.flush() if chunk and self.chunked: chunk_len = ("%x\r\n" % len(chunk)).encode("ascii") chunk = chunk_len + chunk + b"\r\n0\r\n\r\n" diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 942dda507ab..0eecb48ddfc 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -27,6 +27,7 @@ from multidict import CIMultiDict, CIMultiDictProxy, MultiMapping +from .compression_utils import ZLibCompressor, ZLibDecompressor from .hdrs import ( CONTENT_DISPOSITION, CONTENT_ENCODING, @@ -491,15 +492,15 @@ def decode(self, data: bytes) -> bytes: def _decode_content(self, data: bytes) -> bytes: encoding = self.headers.get(CONTENT_ENCODING, "").lower() - - if encoding == "deflate": - return zlib.decompress(data, -zlib.MAX_WBITS) - elif encoding == "gzip": - return zlib.decompress(data, 16 + zlib.MAX_WBITS) - elif encoding == "identity": + if encoding == "identity": return data - else: - raise RuntimeError(f"unknown content encoding: {encoding}") + if encoding in {"deflate", "gzip"}: + return ZLibDecompressor( + encoding=encoding, + suppress_deflate_header=True, + ).decompress_sync(data) + + raise RuntimeError(f"unknown content encoding: {encoding}") def _decode_content_transfer(self, data: bytes) -> bytes: encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower() @@ -976,7 +977,7 @@ class MultipartPayloadWriter: def __init__(self, writer: Any) -> None: 
self._writer = writer self._encoding: Optional[str] = None - self._compress: Any = None + self._compress: Optional[ZLibCompressor] = None self._encoding_buffer: Optional[bytearray] = None def enable_encoding(self, encoding: str) -> None: @@ -989,8 +990,11 @@ def enable_encoding(self, encoding: str) -> None: def enable_compression( self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY ) -> None: - zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else -zlib.MAX_WBITS - self._compress = zlib.compressobj(wbits=zlib_mode, strategy=strategy) + self._compress = ZLibCompressor( + encoding=encoding, + suppress_deflate_header=True, + strategy=strategy, + ) async def write_eof(self) -> None: if self._compress is not None: @@ -1006,7 +1010,7 @@ async def write_eof(self) -> None: async def write(self, chunk: bytes) -> None: if self._compress is not None: if chunk: - chunk = self._compress.compress(chunk) + chunk = await self._compress.compress(chunk) if not chunk: return diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index 87c70487e2d..886b8eb76c6 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -6,7 +6,6 @@ import math import time import warnings -import zlib from concurrent.futures import Executor from http import HTTPStatus from http.cookies import Morsel @@ -25,6 +24,7 @@ from . 
import hdrs, payload from .abc import AbstractStreamWriter +from .compression_utils import ZLibCompressor from .helpers import ( ETAG_ANY, PY_38, @@ -692,13 +692,6 @@ async def _start(self, request: "BaseRequest") -> AbstractStreamWriter: return await super()._start(request) - def _compress_body(self, zlib_mode: int) -> None: - assert zlib_mode > 0 - compressobj = zlib.compressobj(wbits=zlib_mode) - body_in = self._body - assert body_in is not None - self._compressed_body = compressobj.compress(body_in) + compressobj.flush() - async def _do_start_compression(self, coding: ContentCoding) -> None: if self._body_payload or self._chunked: return await super()._do_start_compression(coding) @@ -706,26 +699,26 @@ async def _do_start_compression(self, coding: ContentCoding) -> None: if coding != ContentCoding.identity: # Instead of using _payload_writer.enable_compression, # compress the whole body - zlib_mode = ( - 16 + zlib.MAX_WBITS if coding == ContentCoding.gzip else zlib.MAX_WBITS + compressor = ZLibCompressor( + encoding=str(coding.value), + max_sync_chunk_size=self._zlib_executor_size, + executor=self._zlib_executor, ) - body_in = self._body - assert body_in is not None - if ( - self._zlib_executor_size is not None - and len(body_in) > self._zlib_executor_size - ): - await asyncio.get_event_loop().run_in_executor( - self._zlib_executor, self._compress_body, zlib_mode + assert self._body is not None + if self._zlib_executor_size is None and len(self._body) > 1024 * 1024: + warnings.warn( + "Synchronous compression of large response bodies " + f"({len(self._body)} bytes) might block the async event loop. " + "Consider providing a custom value to zlib_executor_size/" + "zlib_executor response properties or disabling compression on it." 
) - else: - self._compress_body(zlib_mode) - - body_out = self._compressed_body - assert body_out is not None + self._compressed_body = ( + await compressor.compress(self._body) + compressor.flush() + ) + assert self._compressed_body is not None self._headers[hdrs.CONTENT_ENCODING] = coding.value - self._headers[hdrs.CONTENT_LENGTH] = str(len(body_out)) + self._headers[hdrs.CONTENT_LENGTH] = str(len(self._compressed_body)) def json_response( diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index 619a95b4b6f..e4fcbff0e62 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -1119,7 +1119,7 @@ async def test_feed_data(self, stream: Any) -> None: dbuf = DeflateBuffer(buf, "deflate") dbuf.decompressor = mock.Mock() - dbuf.decompressor.decompress.return_value = b"line" + dbuf.decompressor.decompress_sync.return_value = b"line" # First byte should be b'x' in order code not to change the decoder. dbuf.feed_data(b"xxxx", 4) @@ -1133,7 +1133,7 @@ async def test_feed_data_err(self, stream: Any) -> None: exc = ValueError() dbuf.decompressor = mock.Mock() - dbuf.decompressor.decompress.side_effect = exc + dbuf.decompressor.decompress_sync.side_effect = exc with pytest.raises(http_exceptions.ContentEncodingError): # Should be more than 4 bytes to trigger deflate FSM error.