Skip to content

Commit

Permalink
refactor: Extract zlib-related logic into a single module (#7223)
Browse files Browse the repository at this point in the history
<!-- Thank you for your contribution! -->

## What do these changes do?

Addresses issue #7192
Refactors the logic to have the zlib-related stuff concentrated into a
single module

## Are there changes in behavior for the user?

No

## Related issue number

#7192

## Checklist

- [x] I think the code is well written
- [x] Unit tests for the changes exist
- [ ] Documentation reflects the changes
- [ ] If you provide code modification, please add yourself to
`CONTRIBUTORS.txt`
  * The format is &lt;Name&gt; &lt;Surname&gt;.
  * Please keep alphabetical order, the file is sorted by names.
- [ ] Add a new news fragment into the `CHANGES` folder
  * name it `<issue_id>.<type>` for example (588.bugfix)
* if you don't have an `issue_id` change it to the pr id after creating
the pr
  * ensure type is one of the following:
    * `.feature`: Signifying a new feature.
    * `.bugfix`: Signifying a bug fix.
    * `.doc`: Signifying a documentation improvement.
    * `.removal`: Signifying a deprecation or removal of public API.
* `.misc`: A ticket has been closed, but it is not of interest to users.
* Make sure to use full sentences with correct case and punctuation, for
example: "Fix issue with non-ascii contents in doctest text files."

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sam Bull <[email protected]>
  • Loading branch information
3 people authored Mar 7, 2023
1 parent 3058c72 commit 3ff81dc
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 87 deletions.
2 changes: 1 addition & 1 deletion aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
InvalidURL,
ServerFingerprintMismatch,
)
from .compression_utils import HAS_BROTLI
from .formdata import FormData
from .hdrs import CONTENT_TYPE
from .helpers import (
Expand All @@ -51,7 +52,6 @@
set_result,
)
from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, StreamWriter
from .http_parser import HAS_BROTLI
from .log import client_logger
from .streams import StreamReader
from .typedefs import (
Expand Down
148 changes: 148 additions & 0 deletions aiohttp/compression_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import asyncio
import zlib
from concurrent.futures import Executor
from typing import Optional, cast

try:
import brotli

HAS_BROTLI = True
except ImportError: # pragma: no cover
HAS_BROTLI = False

MAX_SYNC_CHUNK_SIZE = 1024


def encoding_to_mode(
encoding: Optional[str] = None,
suppress_deflate_header: bool = False,
) -> int:
if encoding == "gzip":
return 16 + zlib.MAX_WBITS

return -zlib.MAX_WBITS if suppress_deflate_header else zlib.MAX_WBITS


class ZlibBaseHandler:
def __init__(
self,
mode: int,
executor: Optional[Executor] = None,
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
):
self._mode = mode
self._executor = executor
self._max_sync_chunk_size = max_sync_chunk_size


class ZLibCompressor(ZlibBaseHandler):
def __init__(
self,
encoding: Optional[str] = None,
suppress_deflate_header: bool = False,
level: Optional[int] = None,
wbits: Optional[int] = None,
strategy: int = zlib.Z_DEFAULT_STRATEGY,
executor: Optional[Executor] = None,
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
):
super().__init__(
mode=encoding_to_mode(encoding, suppress_deflate_header)
if wbits is None
else wbits,
executor=executor,
max_sync_chunk_size=max_sync_chunk_size,
)
if level is None:
self._compressor = zlib.compressobj(wbits=self._mode, strategy=strategy)
else:
self._compressor = zlib.compressobj(
wbits=self._mode, strategy=strategy, level=level
)

def compress_sync(self, data: bytes) -> bytes:
return self._compressor.compress(data)

async def compress(self, data: bytes) -> bytes:
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_event_loop().run_in_executor(
self._executor, self.compress_sync, data
)
return self.compress_sync(data)

def flush(self, mode: int = zlib.Z_FINISH) -> bytes:
return self._compressor.flush(mode)


class ZLibDecompressor(ZlibBaseHandler):
def __init__(
self,
encoding: Optional[str] = None,
suppress_deflate_header: bool = False,
executor: Optional[Executor] = None,
max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE,
):
super().__init__(
mode=encoding_to_mode(encoding, suppress_deflate_header),
executor=executor,
max_sync_chunk_size=max_sync_chunk_size,
)
self._decompressor = zlib.decompressobj(wbits=self._mode)

def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes:
return self._decompressor.decompress(data, max_length)

async def decompress(self, data: bytes, max_length: int = 0) -> bytes:
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_event_loop().run_in_executor(
self._executor, self.decompress_sync, data, max_length
)
return self.decompress_sync(data, max_length)

def flush(self, length: int = 0) -> bytes:
return (
self._decompressor.flush(length)
if length > 0
else self._decompressor.flush()
)

@property
def eof(self) -> bool:
return self._decompressor.eof

@property
def unconsumed_tail(self) -> bytes:
return self._decompressor.unconsumed_tail

@property
def unused_data(self) -> bytes:
return self._decompressor.unused_data


class BrotliDecompressor:
# Supports both 'brotlipy' and 'Brotli' packages
# since they share an import name. The top branches
# are for 'brotlipy' and bottom branches for 'Brotli'
def __init__(self) -> None:
if not HAS_BROTLI:
raise RuntimeError(
"The brotli decompression is not available. "
"Please install `Brotli` module"
)
self._obj = brotli.Decompressor()

def decompress_sync(self, data: bytes) -> bytes:
if hasattr(self._obj, "decompress"):
return cast(bytes, self._obj.decompress(data))
return cast(bytes, self._obj.process(data))

def flush(self) -> bytes:
if hasattr(self._obj, "flush"):
return cast(bytes, self._obj.flush())
return b""
44 changes: 9 additions & 35 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import collections
import re
import string
import zlib
from contextlib import suppress
from enum import IntEnum
from typing import (
Any,
Generic,
List,
NamedTuple,
Expand All @@ -18,7 +16,6 @@
Type,
TypeVar,
Union,
cast,
)

from multidict import CIMultiDict, CIMultiDictProxy, istr
Expand All @@ -27,6 +24,7 @@

from . import hdrs
from .base_protocol import BaseProtocol
from .compression_utils import HAS_BROTLI, BrotliDecompressor, ZLibDecompressor
from .helpers import NO_EXTENSIONS, BaseTimerContext
from .http_exceptions import (
BadHttpMessage,
Expand All @@ -42,14 +40,6 @@
from .streams import EMPTY_PAYLOAD, StreamReader
from .typedefs import RawHeaders

try:
import brotli

HAS_BROTLI = True
except ImportError: # pragma: no cover
HAS_BROTLI = False


__all__ = (
"HeadersParser",
"HttpParser",
Expand Down Expand Up @@ -859,34 +849,16 @@ def __init__(self, out: StreamReader, encoding: Optional[str]) -> None:
self.encoding = encoding
self._started_decoding = False

self.decompressor: Union[BrotliDecompressor, ZLibDecompressor]
if encoding == "br":
if not HAS_BROTLI: # pragma: no cover
raise ContentEncodingError(
"Can not decode content-encoding: brotli (br). "
"Please install `Brotli`"
)

class BrotliDecoder:
# Supports both 'brotlipy' and 'Brotli' packages
# since they share an import name. The top branches
# are for 'brotlipy' and bottom branches for 'Brotli'
def __init__(self) -> None:
self._obj = brotli.Decompressor()

def decompress(self, data: bytes) -> bytes:
if hasattr(self._obj, "decompress"):
return cast(bytes, self._obj.decompress(data))
return cast(bytes, self._obj.process(data))

def flush(self) -> bytes:
if hasattr(self._obj, "flush"):
return cast(bytes, self._obj.flush())
return b""

self.decompressor: Any = BrotliDecoder()
self.decompressor = BrotliDecompressor()
else:
zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else zlib.MAX_WBITS
self.decompressor = zlib.decompressobj(wbits=zlib_mode)
self.decompressor = ZLibDecompressor(encoding=encoding)

def set_exception(self, exc: BaseException) -> None:
self.out.set_exception(exc)
Expand All @@ -907,10 +879,12 @@ def feed_data(self, chunk: bytes, size: int) -> None:
):
# Change the decoder to decompress incorrectly compressed data
# Actually we should issue a warning about non-RFC-compliant data.
self.decompressor = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
self.decompressor = ZLibDecompressor(
encoding=self.encoding, suppress_deflate_header=True
)

try:
chunk = self.decompressor.decompress(chunk)
chunk = self.decompressor.decompress_sync(chunk)
except Exception:
raise ContentEncodingError(
"Can not decode content-encoding: %s" % self.encoding
Expand All @@ -926,7 +900,7 @@ def feed_eof(self) -> None:

if chunk or self.size > 0:
self.out.feed_data(chunk, len(chunk))
if self.encoding == "deflate" and not self.decompressor.eof:
if self.encoding == "deflate" and not self.decompressor.eof: # type: ignore
raise ContentEncodingError("deflate")

self.out.feed_eof()
Expand Down
16 changes: 9 additions & 7 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing_extensions import Final

from .base_protocol import BaseProtocol
from .compression_utils import ZLibCompressor, ZLibDecompressor
from .helpers import NO_EXTENSIONS
from .streams import DataQueue

Expand Down Expand Up @@ -270,7 +271,7 @@ def __init__(
self._payload_length = 0
self._payload_length_flag = 0
self._compressed: Optional[bool] = None
self._decompressobj: Any = None # zlib.decompressobj actually
self._decompressobj: Optional[ZLibDecompressor] = None
self._compress = compress

def feed_eof(self) -> None:
Expand All @@ -290,7 +291,7 @@ def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
for fin, opcode, payload, compressed in self.parse_frame(data):
if compressed and not self._decompressobj:
self._decompressobj = zlib.decompressobj(wbits=-zlib.MAX_WBITS)
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
if opcode == WSMsgType.CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
Expand Down Expand Up @@ -375,8 +376,9 @@ def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
# Decompress process must to be done after all packets
# received.
if compressed:
assert self._decompressobj is not None
self._partial.extend(_WS_DEFLATE_TRAILING)
payload_merged = self._decompressobj.decompress(
payload_merged = self._decompressobj.decompress_sync(
self._partial, self._max_msg_size
)
if self._decompressobj.unconsumed_tail:
Expand Down Expand Up @@ -604,16 +606,16 @@ async def _send_frame(
if (compress or self.compress) and opcode < 8:
if compress:
# Do not set self._compress if compressing is for this frame
compressobj = zlib.compressobj(level=zlib.Z_BEST_SPEED, wbits=-compress)
compressobj = ZLibCompressor(level=zlib.Z_BEST_SPEED, wbits=-compress)
else: # self.compress
if not self._compressobj:
self._compressobj = zlib.compressobj(
self._compressobj = ZLibCompressor(
level=zlib.Z_BEST_SPEED, wbits=-self.compress
)
compressobj = self._compressobj

message = compressobj.compress(message)
message = message + compressobj.flush(
message = await compressobj.compress(message)
message += compressobj.flush(
zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
)
if message.endswith(_WS_DEFLATE_TRAILING):
Expand Down
12 changes: 6 additions & 6 deletions aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .abc import AbstractStreamWriter
from .base_protocol import BaseProtocol
from .compression_utils import ZLibCompressor
from .helpers import NO_EXTENSIONS

__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11")
Expand Down Expand Up @@ -43,7 +44,7 @@ def __init__(
self.output_size = 0

self._eof = False
self._compress: Any = None
self._compress: Optional[ZLibCompressor] = None
self._drain_waiter = None

self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent
Expand All @@ -63,8 +64,7 @@ def enable_chunking(self) -> None:
def enable_compression(
self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY
) -> None:
zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else zlib.MAX_WBITS
self._compress = zlib.compressobj(wbits=zlib_mode, strategy=strategy)
self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)

def _write(self, chunk: bytes) -> None:
size = len(chunk)
Expand Down Expand Up @@ -93,7 +93,7 @@ async def write(
chunk = chunk.cast("c")

if self._compress is not None:
chunk = self._compress.compress(chunk)
chunk = await self._compress.compress(chunk)
if not chunk:
return

Expand Down Expand Up @@ -138,9 +138,9 @@ async def write_eof(self, chunk: bytes = b"") -> None:

if self._compress:
if chunk:
chunk = self._compress.compress(chunk)
chunk = await self._compress.compress(chunk)

chunk = chunk + self._compress.flush()
chunk += self._compress.flush()
if chunk and self.chunked:
chunk_len = ("%x\r\n" % len(chunk)).encode("ascii")
chunk = chunk_len + chunk + b"\r\n0\r\n\r\n"
Expand Down
Loading

0 comments on commit 3ff81dc

Please sign in to comment.