diff --git a/aiokafka/record/_crecords/default_records.pyi b/aiokafka/record/_crecords/default_records.pyi
index 1392aff0..c1815317 100644
--- a/aiokafka/record/_crecords/default_records.pyi
+++ b/aiokafka/record/_crecords/default_records.pyi
@@ -8,6 +8,15 @@ from aiokafka.record._protocols import (
     DefaultRecordMetadataProtocol,
     DefaultRecordProtocol,
 )
+from aiokafka.record._types import (
+    CodecGzipT,
+    CodecLz4T,
+    CodecMaskT,
+    CodecNoneT,
+    CodecSnappyT,
+    CodecZstdT,
+    DefaultCompressionTypeT,
+)
 
 @final
 class DefaultRecord(DefaultRecordProtocol):
@@ -33,12 +42,12 @@ class DefaultRecord(DefaultRecordProtocol):
 
 @final
 class DefaultRecordBatch(DefaultRecordBatchProtocol):
-    CODEC_NONE: ClassVar[int]
-    CODEC_MASK: ClassVar[int]
-    CODEC_GZIP: ClassVar[int]
-    CODEC_SNAPPY: ClassVar[int]
-    CODEC_LZ4: ClassVar[int]
-    CODEC_ZSTD: ClassVar[int]
+    CODEC_MASK: ClassVar[CodecMaskT]
+    CODEC_NONE: ClassVar[CodecNoneT]
+    CODEC_GZIP: ClassVar[CodecGzipT]
+    CODEC_SNAPPY: ClassVar[CodecSnappyT]
+    CODEC_LZ4: ClassVar[CodecLz4T]
+    CODEC_ZSTD: ClassVar[CodecZstdT]
 
     base_offset: int
     length: int
@@ -75,7 +84,7 @@ class DefaultRecordBatchBuilder(DefaultRecordBatchBuilderProtocol):
     def __init__(
         self,
         magic: int,
-        compression_type: int,
+        compression_type: DefaultCompressionTypeT,
         is_transactional: int,
         producer_id: int,
         producer_epoch: int,
diff --git a/aiokafka/record/_crecords/legacy_records.pyi b/aiokafka/record/_crecords/legacy_records.pyi
index a78d071e..9e304451 100644
--- a/aiokafka/record/_crecords/legacy_records.pyi
+++ b/aiokafka/record/_crecords/legacy_records.pyi
@@ -8,6 +8,13 @@ from aiokafka.record._protocols import (
     LegacyRecordMetadataProtocol,
     LegacyRecordProtocol,
 )
+from aiokafka.record._types import (
+    CodecGzipT,
+    CodecLz4T,
+    CodecMaskT,
+    CodecSnappyT,
+    LegacyCompressionTypeT,
+)
 
 @final
 class LegacyRecord(LegacyRecordProtocol):
@@ -38,10 +45,10 @@ class LegacyRecord(LegacyRecordProtocol):
 class LegacyRecordBatch(LegacyRecordBatchProtocol):
     RECORD_OVERHEAD_V0: ClassVar[int]
     RECORD_OVERHEAD_V1: ClassVar[int]
-    CODEC_MASK: ClassVar[int]
-    CODEC_GZIP: ClassVar[int]
-    CODEC_SNAPPY: ClassVar[int]
-    CODEC_LZ4: ClassVar[int]
+    CODEC_MASK: ClassVar[CodecMaskT]
+    CODEC_GZIP: ClassVar[CodecGzipT]
+    CODEC_SNAPPY: ClassVar[CodecSnappyT]
+    CODEC_LZ4: ClassVar[CodecLz4T]
 
     is_control_batch: bool
     is_transactional: bool
@@ -54,12 +61,14 @@ class LegacyRecordBatch(LegacyRecordBatchProtocol):
 
 @final
 class LegacyRecordBatchBuilder(LegacyRecordBatchBuilderProtocol):
-    CODEC_MASK: ClassVar[int]
-    CODEC_GZIP: ClassVar[int]
-    CODEC_SNAPPY: ClassVar[int]
-    CODEC_LZ4: ClassVar[int]
+    CODEC_MASK: ClassVar[CodecMaskT]
+    CODEC_GZIP: ClassVar[CodecGzipT]
+    CODEC_SNAPPY: ClassVar[CodecSnappyT]
+    CODEC_LZ4: ClassVar[CodecLz4T]
 
-    def __init__(self, magic: int, compression_type: int, batch_size: int) -> None: ...
+    def __init__(
+        self, magic: int, compression_type: LegacyCompressionTypeT, batch_size: int
+    ) -> None: ...
     def append(
         self,
         offset: int,
diff --git a/aiokafka/record/_protocols.py b/aiokafka/record/_protocols.py
index 8dfc4d0e..176932b1 100644
--- a/aiokafka/record/_protocols.py
+++ b/aiokafka/record/_protocols.py
@@ -15,12 +15,23 @@
 
 from typing_extensions import Literal, Never
 
+from ._types import (
+    CodecGzipT,
+    CodecLz4T,
+    CodecMaskT,
+    CodecNoneT,
+    CodecSnappyT,
+    CodecZstdT,
+    DefaultCompressionTypeT,
+    LegacyCompressionTypeT,
+)
+
 
 class DefaultRecordBatchBuilderProtocol(Protocol):
     def __init__(
         self,
         magic: int,
-        compression_type: int,
+        compression_type: DefaultCompressionTypeT,
         is_transactional: int,
         producer_id: int,
         producer_epoch: int,
@@ -83,12 +94,12 @@ def timestamp(self) -> int: ...
 
 
 class DefaultRecordBatchProtocol(Iterator["DefaultRecordProtocol"], Protocol):
-    CODEC_MASK: ClassVar[int]
-    CODEC_NONE: ClassVar[int]
-    CODEC_GZIP: ClassVar[int]
-    CODEC_SNAPPY: ClassVar[int]
-    CODEC_LZ4: ClassVar[int]
-    CODEC_ZSTD: ClassVar[int]
+    CODEC_MASK: ClassVar[CodecMaskT]
+    CODEC_NONE: ClassVar[CodecNoneT]
+    CODEC_GZIP: ClassVar[CodecGzipT]
+    CODEC_SNAPPY: ClassVar[CodecSnappyT]
+    CODEC_LZ4: ClassVar[CodecLz4T]
+    CODEC_ZSTD: ClassVar[CodecZstdT]
 
     def __init__(self, buffer: Union[bytes, bytearray, memoryview]) -> None: ...
     @property
@@ -161,7 +172,10 @@ def checksum(self) -> None: ...
 
 class LegacyRecordBatchBuilderProtocol(Protocol):
     def __init__(
-        self, magic: Literal[0, 1], compression_type: int, batch_size: int
+        self,
+        magic: Literal[0, 1],
+        compression_type: LegacyCompressionTypeT,
+        batch_size: int,
     ) -> None: ...
     def append(
         self,
@@ -203,10 +217,10 @@ def timestamp(self) -> int: ...
 
 
 class LegacyRecordBatchProtocol(Iterable["LegacyRecordProtocol"], Protocol):
-    CODEC_MASK: ClassVar[int]
-    CODEC_GZIP: ClassVar[int]
-    CODEC_SNAPPY: ClassVar[int]
-    CODEC_LZ4: ClassVar[int]
+    CODEC_MASK: ClassVar[CodecMaskT]
+    CODEC_GZIP: ClassVar[CodecGzipT]
+    CODEC_SNAPPY: ClassVar[CodecSnappyT]
+    CODEC_LZ4: ClassVar[CodecLz4T]
 
     is_control_batch: bool
     is_transactional: bool
diff --git a/aiokafka/record/default_records.py b/aiokafka/record/default_records.py
index a210eb2e..2903b58a 100644
--- a/aiokafka/record/default_records.py
+++ b/aiokafka/record/default_records.py
@@ -58,7 +58,7 @@
 import time
 from typing import Any, Callable, List, Optional, Sized, Tuple, Type, Union, final
 
-from typing_extensions import Self
+from typing_extensions import Self, TypeIs, assert_never
 
 import aiokafka.codec as codecs
 from aiokafka.codec import (
@@ -80,6 +80,14 @@
     DefaultRecordMetadataProtocol,
     DefaultRecordProtocol,
 )
+from ._types import (
+    CodecGzipT,
+    CodecLz4T,
+    CodecMaskT,
+    CodecNoneT,
+    CodecSnappyT,
+    CodecZstdT,
+)
 from .util import calc_crc32c, decode_varint, encode_varint, size_of_varint
 
 
@@ -106,12 +114,12 @@ class DefaultRecordBase:
     CRC_OFFSET = struct.calcsize(">qiib")
     AFTER_LEN_OFFSET = struct.calcsize(">qi")
 
-    CODEC_MASK = 0x07
-    CODEC_NONE = 0x00
-    CODEC_GZIP = 0x01
-    CODEC_SNAPPY = 0x02
-    CODEC_LZ4 = 0x03
-    CODEC_ZSTD = 0x04
+    CODEC_MASK: CodecMaskT = 0x07
+    CODEC_NONE: CodecNoneT = 0x00
+    CODEC_GZIP: CodecGzipT = 0x01
+    CODEC_SNAPPY: CodecSnappyT = 0x02
+    CODEC_LZ4: CodecLz4T = 0x03
+    CODEC_ZSTD: CodecZstdT = 0x04
     TIMESTAMP_TYPE_MASK = 0x08
     TRANSACTIONAL_MASK = 0x10
     CONTROL_MASK = 0x20
@@ -121,7 +129,9 @@ class DefaultRecordBase:
 
     NO_PARTITION_LEADER_EPOCH = -1
 
-    def _assert_has_codec(self, compression_type: int) -> None:
+    def _assert_has_codec(
+        self, compression_type: int
+    ) -> TypeIs[Union[CodecGzipT, CodecSnappyT, CodecLz4T, CodecZstdT]]:
         if compression_type == self.CODEC_GZIP:
             checker, name = codecs.has_gzip, "gzip"
         elif compression_type == self.CODEC_SNAPPY:
@@ -138,6 +148,7 @@ def _assert_has_codec(self, compression_type: int) -> None:
             raise UnsupportedCodecError(
                 f"Libraries for {name} compression codec not found"
             )
+        return True
 
 
 @final
@@ -216,7 +227,7 @@ def _maybe_uncompress(self) -> None:
         if not self._decompressed:
             compression_type = self.compression_type
             if compression_type != self.CODEC_NONE:
-                self._assert_has_codec(compression_type)
+                assert self._assert_has_codec(compression_type)
                 data = memoryview(self._buffer)[self._pos :]
                 if compression_type == self.CODEC_GZIP:
                     uncompressed = gzip_decode(data)
@@ -224,8 +235,10 @@ def _maybe_uncompress(self) -> None:
                     uncompressed = snappy_decode(data.tobytes())
                 elif compression_type == self.CODEC_LZ4:
                     uncompressed = lz4_decode(data.tobytes())
-                if compression_type == self.CODEC_ZSTD:
+                elif compression_type == self.CODEC_ZSTD:
                     uncompressed = zstd_decode(data.tobytes())
+                else:
+                    assert_never(compression_type)
             self._buffer = bytearray(uncompressed)
             self._pos = 0
             self._decompressed = True
@@ -581,7 +594,7 @@ def _write_header(self, use_compression_type: bool = True) -> None:
     def _maybe_compress(self) -> bool:
         if self._compression_type != self.CODEC_NONE:
-            self._assert_has_codec(self._compression_type)
+            assert self._assert_has_codec(self._compression_type)
             header_size = self.HEADER_STRUCT.size
             data = bytes(self._buffer[header_size:])
             if self._compression_type == self.CODEC_GZIP:
                 compressed = gzip_encode(data)
@@ -592,6 +605,8 @@ def _maybe_compress(self) -> bool:
                 compressed = lz4_encode(data)
             elif self._compression_type == self.CODEC_ZSTD:
                 compressed = zstd_encode(data)
+            else:
+                assert_never(self._compression_type)
             compressed_size = len(compressed)
             if len(data) <= compressed_size:
                 # We did not get any benefit from compression, lets send
diff --git a/aiokafka/record/legacy_records.py b/aiokafka/record/legacy_records.py
index f7e1e804..ea694e77 100644
--- a/aiokafka/record/legacy_records.py
+++ b/aiokafka/record/legacy_records.py
@@ -5,7 +5,7 @@
 from binascii import crc32
 from typing import Any, Generator, List, Optional, Tuple, Type, Union, final
 
-from typing_extensions import Literal, Never
+from typing_extensions import Literal, Never, TypeIs, assert_never
 
 import aiokafka.codec as codecs
 from aiokafka.codec import (
@@ -25,6 +25,13 @@
     LegacyRecordMetadataProtocol,
     LegacyRecordProtocol,
 )
+from ._types import (
+    CodecGzipT,
+    CodecLz4T,
+    CodecMaskT,
+    CodecSnappyT,
+    LegacyCompressionTypeT,
+)
 
 NoneType = type(None)
 
@@ -78,16 +85,18 @@ class LegacyRecordBase:
     KEY_OFFSET_V1 = HEADER_STRUCT_V1.size
     KEY_LENGTH = VALUE_LENGTH = struct.calcsize(">i")  # Bytes length is Int32
 
-    CODEC_MASK = 0x07
-    CODEC_GZIP = 0x01
-    CODEC_SNAPPY = 0x02
-    CODEC_LZ4 = 0x03
+    CODEC_MASK: CodecMaskT = 0x07
+    CODEC_GZIP: CodecGzipT = 0x01
+    CODEC_SNAPPY: CodecSnappyT = 0x02
+    CODEC_LZ4: CodecLz4T = 0x03
     TIMESTAMP_TYPE_MASK = 0x08
 
     LOG_APPEND_TIME = 1
     CREATE_TIME = 0
 
-    def _assert_has_codec(self, compression_type: int) -> None:
+    def _assert_has_codec(
+        self, compression_type: int
+    ) -> TypeIs[Union[CodecGzipT, CodecSnappyT, CodecLz4T]]:
         if compression_type == self.CODEC_GZIP:
             checker, name = codecs.has_gzip, "gzip"
         elif compression_type == self.CODEC_SNAPPY:
@@ -102,6 +111,7 @@ def _assert_has_codec(self, compression_type: int) -> None:
             raise UnsupportedCodecError(
                 f"Libraries for {name} compression codec not found"
             )
+        return True
 
 
 @final
@@ -165,7 +175,7 @@ def _decompress(self, key_offset: int) -> bytes:
             data = self._buffer[pos : pos + value_size]
 
         compression_type = self._compression_type
-        self._assert_has_codec(compression_type)
+        assert self._assert_has_codec(compression_type)
         if compression_type == self.CODEC_GZIP:
             uncompressed = gzip_decode(data)
         elif compression_type == self.CODEC_SNAPPY:
@@ -178,6 +188,8 @@ def _decompress(self, key_offset: int) -> bytes:
                 )
             else:
                 uncompressed = lz4_decode(data.tobytes())
+        else:
+            assert_never(compression_type)
         return uncompressed
 
     def _read_header(self, pos: int) -> Tuple[int, int, int, int, int, Optional[int]]:
@@ -341,7 +353,10 @@ class _LegacyRecordBatchBuilderPy(LegacyRecordBase, LegacyRecordBatchBuilderProt
     _buffer: Optional[bytearray] = None
 
     def __init__(
-        self, magic: Literal[0, 1], compression_type: int, batch_size: int
+        self,
+        magic: Literal[0, 1],
+        compression_type: LegacyCompressionTypeT,
+        batch_size: int,
     ) -> None:
         assert magic in [0, 1]
         self._magic = magic
@@ -478,7 +493,7 @@ def _encode_msg(
     def _maybe_compress(self) -> bool:
         if self._compression_type:
             assert self._buffer is not None
-            self._assert_has_codec(self._compression_type)
+            assert self._assert_has_codec(self._compression_type)
             buf = self._buffer
             if self._compression_type == self.CODEC_GZIP:
                 compressed = gzip_encode(buf)
@@ -492,6 +507,9 @@ def _maybe_compress(self) -> bool:
                     )
                 else:
                     compressed = lz4_encode(bytes(buf))
+
+            else:
+                assert_never(self._compression_type)
             compressed_size = len(compressed)
             size = self._size_in_bytes(key_size=0, value_size=compressed_size)
             if size > len(self._buffer):
diff --git a/pyproject.toml b/pyproject.toml
index 534987eb..8121a35b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,7 +31,7 @@ dynamic = ["version"]
 dependencies = [
     "async-timeout",
     "packaging",
-    "typing_extensions >=4.6.0",
+    "typing_extensions >=4.10.0",
 ]
 
 [project.optional-dependencies]
diff --git a/requirements-ci.txt b/requirements-ci.txt
index 69136120..08592f03 100644
--- a/requirements-ci.txt
+++ b/requirements-ci.txt
@@ -1,6 +1,6 @@
 -r requirements-cython.txt
 ruff==0.3.4
-mypy==1.9.0
+mypy==1.10.0
 pytest==7.4.3
 pytest-cov==4.1.0
 pytest-asyncio==0.21.1
diff --git a/requirements-win-test.txt b/requirements-win-test.txt
index fab457ce..6a781e65 100644
--- a/requirements-win-test.txt
+++ b/requirements-win-test.txt
@@ -1,6 +1,6 @@
 -r requirements-cython.txt
 ruff==0.3.2
-mypy==1.9.0
+mypy==1.10.0
 pytest==7.4.3
 pytest-cov==4.1.0
 pytest-asyncio==0.21.1
diff --git a/tests/record/test_default_records.py b/tests/record/test_default_records.py
index 9e8e9764..74d893d0 100644
--- a/tests/record/test_default_records.py
+++ b/tests/record/test_default_records.py
@@ -2,6 +2,7 @@
 from unittest import mock
 
 import pytest
+from typing_extensions import Literal
 
 import aiokafka.codec
 from aiokafka.errors import UnsupportedCodecError
@@ -27,7 +28,9 @@
         pytest.param(DefaultRecordBatch.CODEC_ZSTD, 1714138923, id="zstd"),
     ],
 )
-def test_read_write_serde_v2(compression_type: int, crc: int) -> None:
+def test_read_write_serde_v2(
+    compression_type: Literal[0x00, 0x01, 0x02, 0x03, 0x04], crc: int
+) -> None:
     builder = DefaultRecordBatchBuilder(
         magic=2,
         compression_type=compression_type,
@@ -240,7 +243,9 @@ def test_default_batch_size_limit() -> None:
         (DefaultRecordBatch.CODEC_ZSTD, "zstd", "has_zstd"),
     ],
 )
-def test_unavailable_codec(compression_type: int, name: str, checker_name: str) -> None:
+def test_unavailable_codec(
+    compression_type: Literal[0x01, 0x02, 0x03, 0x04], name: str, checker_name: str
+) -> None:
     builder = DefaultRecordBatchBuilder(
         magic=2,
         compression_type=compression_type,
@@ -279,7 +284,7 @@ def test_unsupported_yet_codec() -> None:
     compression_type = DefaultRecordBatch.CODEC_MASK  # It doesn't exist
     builder = DefaultRecordBatchBuilder(
         magic=2,
-        compression_type=compression_type,
+        compression_type=compression_type,  # type: ignore[arg-type]
        is_transactional=0,
        producer_id=-1,
        producer_epoch=-1,
diff --git a/tests/record/test_legacy.py b/tests/record/test_legacy.py
index 62a636ff..9faa3bc0 100644
--- a/tests/record/test_legacy.py
+++ b/tests/record/test_legacy.py
@@ -65,7 +65,7 @@ def test_read_write_serde_v0_v1_no_compression(
     ],
 )
 def test_read_write_serde_v0_v1_with_compression(
-    compression_type: int, magic: Literal[0, 1]
+    compression_type: Literal[0x01, 0x02, 0x03], magic: Literal[0, 1]
 ) -> None:
     builder = LegacyRecordBatchBuilder(
         magic=magic, compression_type=compression_type, batch_size=1024 * 1024
     )
@@ -226,7 +226,9 @@ def test_legacy_batch_size_limit(magic: Literal[0, 1]) -> None:
         (LegacyRecordBatch.CODEC_SNAPPY, "snappy", "has_snappy"),
     ],
 )
-def test_unavailable_codec(compression_type: int, name: str, checker_name: str) -> None:
+def test_unavailable_codec(
+    compression_type: Literal[0x01, 0x02], name: str, checker_name: str
+) -> None:
     builder = LegacyRecordBatchBuilder(
         magic=0, compression_type=compression_type, batch_size=1024
     )
@@ -253,7 +255,9 @@ def test_unavailable_codec(compression_type: int, name: str, checker_name: str)
 def test_unsupported_yet_codec() -> None:
     compression_type = LegacyRecordBatch.CODEC_MASK  # It doesn't exist
     builder = LegacyRecordBatchBuilder(
-        magic=0, compression_type=compression_type, batch_size=1024
+        magic=0,
+        compression_type=compression_type,  # type: ignore[arg-type]
+        batch_size=1024,
     )
     with pytest.raises(UnsupportedCodecError):
        builder.append(0, timestamp=None, key=None, value=b"M")
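
Note: the hunks above import CodecMaskT, CodecNoneT, CodecGzipT, CodecSnappyT, CodecLz4T,
CodecZstdT, DefaultCompressionTypeT and LegacyCompressionTypeT from the new module
aiokafka/record/_types.py, but that file itself is not shown in this section. The following
is only a hypothetical sketch of what it presumably contains, inferred from the constant
assignments in default_records.py/legacy_records.py (e.g. CODEC_GZIP: CodecGzipT = 0x01)
and from the Literal annotations used in the updated tests; the exact aliases and the
composition of the two Union types are assumptions, not the actual file contents.

    # aiokafka/record/_types.py -- hypothetical sketch, not taken from this diff.
    from typing import Union

    from typing_extensions import Literal

    # One Literal alias per codec bit value used by the record batch formats.
    CodecNoneT = Literal[0x00]
    CodecGzipT = Literal[0x01]
    CodecSnappyT = Literal[0x02]
    CodecLz4T = Literal[0x03]
    CodecZstdT = Literal[0x04]
    CodecMaskT = Literal[0x07]

    # v2 ("default") batches accept gzip/snappy/lz4/zstd; v0/v1 ("legacy") batches
    # have no zstd support, which is why the legacy stubs import fewer aliases.
    DefaultCompressionTypeT = Union[
        CodecNoneT, CodecGzipT, CodecSnappyT, CodecLz4T, CodecZstdT
    ]
    LegacyCompressionTypeT = Union[CodecNoneT, CodecGzipT, CodecSnappyT, CodecLz4T]

With aliases like these, passing DefaultRecordBatch.CODEC_MASK (0x07) as a compression_type
no longer type-checks, which is why the two test_unsupported_yet_codec tests gained
"# type: ignore[arg-type]" comments, and the assert/TypeIs/assert_never changes let mypy
verify that every codec branch in _maybe_compress/_maybe_uncompress is exhaustive.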