Skip to content

Commit

Permalink
use TypeIs
Browse files Browse the repository at this point in the history
  • Loading branch information
dimastbk committed Apr 25, 2024
1 parent 58bc26c commit 686e4f2
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 57 deletions.
23 changes: 16 additions & 7 deletions aiokafka/record/_crecords/default_records.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@ from aiokafka.record._protocols import (
DefaultRecordMetadataProtocol,
DefaultRecordProtocol,
)
from aiokafka.record._types import (
CodecGzipT,
CodecLz4T,
CodecMaskT,
CodecNoneT,
CodecSnappyT,
CodecZstdT,
DefaultCompressionTypeT,
)

@final
class DefaultRecord(DefaultRecordProtocol):
Expand All @@ -33,12 +42,12 @@ class DefaultRecord(DefaultRecordProtocol):

@final
class DefaultRecordBatch(DefaultRecordBatchProtocol):
CODEC_NONE: ClassVar[int]
CODEC_MASK: ClassVar[int]
CODEC_GZIP: ClassVar[int]
CODEC_SNAPPY: ClassVar[int]
CODEC_LZ4: ClassVar[int]
CODEC_ZSTD: ClassVar[int]
CODEC_MASK: ClassVar[CodecMaskT]
CODEC_NONE: ClassVar[CodecNoneT]
CODEC_GZIP: ClassVar[CodecGzipT]
CODEC_SNAPPY: ClassVar[CodecSnappyT]
CODEC_LZ4: ClassVar[CodecLz4T]
CODEC_ZSTD: ClassVar[CodecZstdT]

base_offset: int
length: int
Expand Down Expand Up @@ -75,7 +84,7 @@ class DefaultRecordBatchBuilder(DefaultRecordBatchBuilderProtocol):
def __init__(
self,
magic: int,
compression_type: int,
compression_type: DefaultCompressionTypeT,
is_transactional: int,
producer_id: int,
producer_epoch: int,
Expand Down
27 changes: 18 additions & 9 deletions aiokafka/record/_crecords/legacy_records.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ from aiokafka.record._protocols import (
LegacyRecordMetadataProtocol,
LegacyRecordProtocol,
)
from aiokafka.record._types import (
CodecGzipT,
CodecLz4T,
CodecMaskT,
CodecSnappyT,
LegacyCompressionTypeT,
)

@final
class LegacyRecord(LegacyRecordProtocol):
Expand Down Expand Up @@ -38,10 +45,10 @@ class LegacyRecord(LegacyRecordProtocol):
class LegacyRecordBatch(LegacyRecordBatchProtocol):
RECORD_OVERHEAD_V0: ClassVar[int]
RECORD_OVERHEAD_V1: ClassVar[int]
CODEC_MASK: ClassVar[int]
CODEC_GZIP: ClassVar[int]
CODEC_SNAPPY: ClassVar[int]
CODEC_LZ4: ClassVar[int]
CODEC_MASK: ClassVar[CodecMaskT]
CODEC_GZIP: ClassVar[CodecGzipT]
CODEC_SNAPPY: ClassVar[CodecSnappyT]
CODEC_LZ4: ClassVar[CodecLz4T]

is_control_batch: bool
is_transactional: bool
Expand All @@ -54,12 +61,14 @@ class LegacyRecordBatch(LegacyRecordBatchProtocol):

@final
class LegacyRecordBatchBuilder(LegacyRecordBatchBuilderProtocol):
CODEC_MASK: ClassVar[int]
CODEC_GZIP: ClassVar[int]
CODEC_SNAPPY: ClassVar[int]
CODEC_LZ4: ClassVar[int]
CODEC_MASK: ClassVar[CodecMaskT]
CODEC_GZIP: ClassVar[CodecGzipT]
CODEC_SNAPPY: ClassVar[CodecSnappyT]
CODEC_LZ4: ClassVar[CodecLz4T]

def __init__(self, magic: int, compression_type: int, batch_size: int) -> None: ...
def __init__(
self, magic: int, compression_type: LegacyCompressionTypeT, batch_size: int
) -> None: ...
def append(
self,
offset: int,
Expand Down
38 changes: 26 additions & 12 deletions aiokafka/record/_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,23 @@

from typing_extensions import Literal, Never

from ._types import (
CodecGzipT,
CodecLz4T,
CodecMaskT,
CodecNoneT,
CodecSnappyT,
CodecZstdT,
DefaultCompressionTypeT,
LegacyCompressionTypeT,
)


class DefaultRecordBatchBuilderProtocol(Protocol):
def __init__(
self,
magic: int,
compression_type: int,
compression_type: DefaultCompressionTypeT,
is_transactional: int,
producer_id: int,
producer_epoch: int,
Expand Down Expand Up @@ -83,12 +94,12 @@ def timestamp(self) -> int: ...


class DefaultRecordBatchProtocol(Iterator["DefaultRecordProtocol"], Protocol):
CODEC_MASK: ClassVar[int]
CODEC_NONE: ClassVar[int]
CODEC_GZIP: ClassVar[int]
CODEC_SNAPPY: ClassVar[int]
CODEC_LZ4: ClassVar[int]
CODEC_ZSTD: ClassVar[int]
CODEC_MASK: ClassVar[CodecMaskT]
CODEC_NONE: ClassVar[CodecNoneT]
CODEC_GZIP: ClassVar[CodecGzipT]
CODEC_SNAPPY: ClassVar[CodecSnappyT]
CODEC_LZ4: ClassVar[CodecLz4T]
CODEC_ZSTD: ClassVar[CodecZstdT]

def __init__(self, buffer: Union[bytes, bytearray, memoryview]) -> None: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
@property
Expand Down Expand Up @@ -161,7 +172,10 @@ def checksum(self) -> None: ...

class LegacyRecordBatchBuilderProtocol(Protocol):
def __init__(
self, magic: Literal[0, 1], compression_type: int, batch_size: int
self,
magic: Literal[0, 1],
compression_type: LegacyCompressionTypeT,
batch_size: int,
) -> None: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
def append(
self,
Expand Down Expand Up @@ -203,10 +217,10 @@ def timestamp(self) -> int: ...


class LegacyRecordBatchProtocol(Iterable["LegacyRecordProtocol"], Protocol):
CODEC_MASK: ClassVar[int]
CODEC_GZIP: ClassVar[int]
CODEC_SNAPPY: ClassVar[int]
CODEC_LZ4: ClassVar[int]
CODEC_MASK: ClassVar[CodecMaskT]
CODEC_GZIP: ClassVar[CodecGzipT]
CODEC_SNAPPY: ClassVar[CodecSnappyT]
CODEC_LZ4: ClassVar[CodecLz4T]

is_control_batch: bool
is_transactional: bool
Expand Down
14 changes: 14 additions & 0 deletions aiokafka/record/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Union

from typing_extensions import Literal

CodecNoneT = Literal[0x00]
CodecGzipT = Literal[0x01]
CodecSnappyT = Literal[0x02]
CodecLz4T = Literal[0x03]
CodecZstdT = Literal[0x04]
CodecMaskT = Literal[0x07]
DefaultCompressionTypeT = Union[
CodecGzipT, CodecLz4T, CodecNoneT, CodecSnappyT, CodecZstdT
]
LegacyCompressionTypeT = Union[CodecGzipT, CodecLz4T, CodecSnappyT, CodecNoneT]
37 changes: 26 additions & 11 deletions aiokafka/record/default_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
import time
from typing import Any, Callable, List, Optional, Sized, Tuple, Type, Union, final

from typing_extensions import Self
from typing_extensions import Self, TypeIs, assert_never

import aiokafka.codec as codecs
from aiokafka.codec import (
Expand All @@ -80,6 +80,14 @@
DefaultRecordMetadataProtocol,
DefaultRecordProtocol,
)
from ._types import (
CodecGzipT,
CodecLz4T,
CodecMaskT,
CodecNoneT,
CodecSnappyT,
CodecZstdT,
)
from .util import calc_crc32c, decode_varint, encode_varint, size_of_varint


Expand All @@ -106,12 +114,12 @@ class DefaultRecordBase:
CRC_OFFSET = struct.calcsize(">qiib")
AFTER_LEN_OFFSET = struct.calcsize(">qi")

CODEC_MASK = 0x07
CODEC_NONE = 0x00
CODEC_GZIP = 0x01
CODEC_SNAPPY = 0x02
CODEC_LZ4 = 0x03
CODEC_ZSTD = 0x04
CODEC_MASK: CodecMaskT = 0x07
CODEC_NONE: CodecNoneT = 0x00
CODEC_GZIP: CodecGzipT = 0x01
CODEC_SNAPPY: CodecSnappyT = 0x02
CODEC_LZ4: CodecLz4T = 0x03
CODEC_ZSTD: CodecZstdT = 0x04
TIMESTAMP_TYPE_MASK = 0x08
TRANSACTIONAL_MASK = 0x10
CONTROL_MASK = 0x20
Expand All @@ -121,7 +129,9 @@ class DefaultRecordBase:

NO_PARTITION_LEADER_EPOCH = -1

def _assert_has_codec(self, compression_type: int) -> None:
def _assert_has_codec(
self, compression_type: int
) -> TypeIs[Union[CodecGzipT, CodecSnappyT, CodecLz4T, CodecZstdT]]:
if compression_type == self.CODEC_GZIP:
checker, name = codecs.has_gzip, "gzip"
elif compression_type == self.CODEC_SNAPPY:
Expand All @@ -138,6 +148,7 @@ def _assert_has_codec(self, compression_type: int) -> None:
raise UnsupportedCodecError(
f"Libraries for {name} compression codec not found"
)
return True


@final
Expand Down Expand Up @@ -216,16 +227,18 @@ def _maybe_uncompress(self) -> None:
if not self._decompressed:
compression_type = self.compression_type
if compression_type != self.CODEC_NONE:
self._assert_has_codec(compression_type)
assert self._assert_has_codec(compression_type)
data = memoryview(self._buffer)[self._pos :]
if compression_type == self.CODEC_GZIP:
uncompressed = gzip_decode(data)
elif compression_type == self.CODEC_SNAPPY:
uncompressed = snappy_decode(data.tobytes())
elif compression_type == self.CODEC_LZ4:
uncompressed = lz4_decode(data.tobytes())
if compression_type == self.CODEC_ZSTD:
elif compression_type == self.CODEC_ZSTD:
uncompressed = zstd_decode(data.tobytes())
else:
assert_never(compression_type)
self._buffer = bytearray(uncompressed)
self._pos = 0
self._decompressed = True
Expand Down Expand Up @@ -581,7 +594,7 @@ def _write_header(self, use_compression_type: bool = True) -> None:

def _maybe_compress(self) -> bool:
if self._compression_type != self.CODEC_NONE:
self._assert_has_codec(self._compression_type)
assert self._assert_has_codec(self._compression_type)
header_size = self.HEADER_STRUCT.size
data = bytes(self._buffer[header_size:])
if self._compression_type == self.CODEC_GZIP:
Expand All @@ -592,6 +605,8 @@ def _maybe_compress(self) -> bool:
compressed = lz4_encode(data)
elif self._compression_type == self.CODEC_ZSTD:
compressed = zstd_encode(data)
else:
assert_never(self._compression_type)
compressed_size = len(compressed)
if len(data) <= compressed_size:
# We did not get any benefit from compression, lets send
Expand Down
36 changes: 27 additions & 9 deletions aiokafka/record/legacy_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from binascii import crc32
from typing import Any, Generator, List, Optional, Tuple, Type, Union, final

from typing_extensions import Literal, Never
from typing_extensions import Literal, Never, TypeIs, assert_never

import aiokafka.codec as codecs
from aiokafka.codec import (
Expand All @@ -25,6 +25,13 @@
LegacyRecordMetadataProtocol,
LegacyRecordProtocol,
)
from ._types import (
CodecGzipT,
CodecLz4T,
CodecMaskT,
CodecSnappyT,
LegacyCompressionTypeT,
)

NoneType = type(None)

Expand Down Expand Up @@ -78,16 +85,18 @@ class LegacyRecordBase:
KEY_OFFSET_V1 = HEADER_STRUCT_V1.size
KEY_LENGTH = VALUE_LENGTH = struct.calcsize(">i") # Bytes length is Int32

CODEC_MASK = 0x07
CODEC_GZIP = 0x01
CODEC_SNAPPY = 0x02
CODEC_LZ4 = 0x03
CODEC_MASK: CodecMaskT = 0x07
CODEC_GZIP: CodecGzipT = 0x01
CODEC_SNAPPY: CodecSnappyT = 0x02
CODEC_LZ4: CodecLz4T = 0x03
TIMESTAMP_TYPE_MASK = 0x08

LOG_APPEND_TIME = 1
CREATE_TIME = 0

def _assert_has_codec(self, compression_type: int) -> None:
def _assert_has_codec(
self, compression_type: int
) -> TypeIs[Union[CodecGzipT, CodecSnappyT, CodecLz4T]]:
if compression_type == self.CODEC_GZIP:
checker, name = codecs.has_gzip, "gzip"
elif compression_type == self.CODEC_SNAPPY:
Expand All @@ -102,6 +111,7 @@ def _assert_has_codec(self, compression_type: int) -> None:
raise UnsupportedCodecError(
f"Libraries for {name} compression codec not found"
)
return True


@final
Expand Down Expand Up @@ -165,7 +175,7 @@ def _decompress(self, key_offset: int) -> bytes:
data = self._buffer[pos : pos + value_size]

compression_type = self._compression_type
self._assert_has_codec(compression_type)
assert self._assert_has_codec(compression_type)
if compression_type == self.CODEC_GZIP:
uncompressed = gzip_decode(data)
elif compression_type == self.CODEC_SNAPPY:
Expand All @@ -178,6 +188,8 @@ def _decompress(self, key_offset: int) -> bytes:
)
else:
uncompressed = lz4_decode(data.tobytes())
else:
assert_never(compression_type)
return uncompressed

def _read_header(self, pos: int) -> Tuple[int, int, int, int, int, Optional[int]]:
Expand Down Expand Up @@ -341,7 +353,10 @@ class _LegacyRecordBatchBuilderPy(LegacyRecordBase, LegacyRecordBatchBuilderProt
_buffer: Optional[bytearray] = None

def __init__(
self, magic: Literal[0, 1], compression_type: int, batch_size: int
self,
magic: Literal[0, 1],
compression_type: LegacyCompressionTypeT,
batch_size: int,
) -> None:
assert magic in [0, 1]
self._magic = magic
Expand Down Expand Up @@ -478,7 +493,7 @@ def _encode_msg(
def _maybe_compress(self) -> bool:
if self._compression_type:
assert self._buffer is not None
self._assert_has_codec(self._compression_type)
assert self._assert_has_codec(self._compression_type)
buf = self._buffer
if self._compression_type == self.CODEC_GZIP:
compressed = gzip_encode(buf)
Expand All @@ -492,6 +507,9 @@ def _maybe_compress(self) -> bool:
)
else:
compressed = lz4_encode(bytes(buf))

else:
assert_never(self._compression_type)
compressed_size = len(compressed)
size = self._size_in_bytes(key_size=0, value_size=compressed_size)
if size > len(self._buffer):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dynamic = ["version"]
dependencies = [
"async-timeout",
"packaging",
"typing_extensions >=4.6.0",
"typing_extensions >=4.10.0",
]

[project.optional-dependencies]
Expand Down
Loading

0 comments on commit 686e4f2

Please sign in to comment.