From 120c116d13ab46604d54088bb07d851ff5d3fd00 Mon Sep 17 00:00:00 2001
From: 0x26res
Date: Tue, 5 Dec 2023 16:40:24 +0100
Subject: [PATCH] Add kafka json to arrow support (#50)

---
 beavers/pyarrow_kafka.py     | 49 ++++++++++++++++++++++++++++++++++++
 tests/test_pyarrow_kafka.py  | 18 +++++++++++++
 tests/test_pyarrow_replay.py | 11 +-------
 tests/test_util.py           | 11 ++++++++
 4 files changed, 79 insertions(+), 10 deletions(-)
 create mode 100644 beavers/pyarrow_kafka.py
 create mode 100644 tests/test_pyarrow_kafka.py

diff --git a/beavers/pyarrow_kafka.py b/beavers/pyarrow_kafka.py
new file mode 100644
index 0000000..dba31c0
--- /dev/null
+++ b/beavers/pyarrow_kafka.py
@@ -0,0 +1,49 @@
+import dataclasses
+import io
+import json
+
+import confluent_kafka
+import pyarrow as pa
+import pyarrow.json
+
+from beavers.kafka import (
+    KafkaMessageDeserializer,
+    KafkaMessageSerializer,
+    KafkaProducerMessage,
+)
+
+
+@dataclasses.dataclass(frozen=True)
+class JsonDeserializer(KafkaMessageDeserializer[pa.Table]):
+    schema: pa.Schema
+
+    def __call__(self, messages: list[confluent_kafka.Message]) -> pa.Table:
+        if messages:
+            with io.BytesIO() as buffer:
+                for message in messages:
+                    buffer.write(message.value())
+                    buffer.write(b"\n")
+                buffer.seek(0)
+                return pyarrow.json.read_json(
+                    buffer,
+                    parse_options=pyarrow.json.ParseOptions(
+                        explicit_schema=self.schema
+                    ),
+                )
+        else:
+            return self.schema.empty_table()
+
+
+@dataclasses.dataclass(frozen=True)
+class JsonSerializer(KafkaMessageSerializer[pa.Table]):
+    topic: str
+
+    def __call__(self, table: pa.Table):
+        return [
+            KafkaProducerMessage(
+                self.topic,
+                key=None,
+                value=json.dumps(message, default=str).encode("utf-8"),
+            )
+            for message in table.to_pylist()
+        ]
diff --git a/tests/test_pyarrow_kafka.py b/tests/test_pyarrow_kafka.py
new file mode 100644
index 0000000..707a3fa
--- /dev/null
+++ b/tests/test_pyarrow_kafka.py
@@ -0,0 +1,18 @@
+from beavers.pyarrow_kafka import JsonDeserializer, JsonSerializer
+from tests.test_kafka import mock_kafka_message
+from tests.test_util import TEST_TABLE
+
+
+def test_json_deserializer_empty():
+    deserializer = JsonDeserializer(TEST_TABLE.schema)
+    assert deserializer([]) == TEST_TABLE.schema.empty_table()
+
+
+def test_end_to_end():
+    deserializer = JsonDeserializer(TEST_TABLE.schema)
+    serializer = JsonSerializer("topic-1")
+    out_messages = serializer(TEST_TABLE)
+    in_messages = [
+        mock_kafka_message(topic=m.topic, value=m.value) for m in out_messages
+    ]
+    assert deserializer(in_messages) == TEST_TABLE
diff --git a/tests/test_pyarrow_replay.py b/tests/test_pyarrow_replay.py
index 4a8f18b..bc50a42 100644
--- a/tests/test_pyarrow_replay.py
+++ b/tests/test_pyarrow_replay.py
@@ -7,16 +7,7 @@
 
 from beavers.engine import UTC_MAX
 from beavers.pyarrow_replay import ArrowTableDataSink, ArrowTableDataSource
-
-TEST_TABLE = pa.table(
-    {
-        "timestamp": [
-            pd.to_datetime("2023-01-01T00:00:00Z"),
-            pd.to_datetime("2023-01-02T00:00:00Z"),
-        ],
-        "value": [1, 2],
-    }
-)
+from tests.test_util import TEST_TABLE
 
 
 def test_arrow_table_data_source():
diff --git a/tests/test_util.py b/tests/test_util.py
index 06d6e7d..3edf69d 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -4,12 +4,23 @@
 from typing import Callable, Dict, Generic, TypeVar
 
 import pandas as pd
+import pyarrow as pa
 
 from beavers.engine import UTC_MAX, Dag, TimerManager
 from beavers.replay import DataSink, DataSource
 
 T = TypeVar("T")
 
+TEST_TABLE = pa.table(
+    {
+        "timestamp": [
+            pd.to_datetime("2023-01-01T00:00:00Z"),
+            pd.to_datetime("2023-01-02T00:00:00Z"),
+        ],
+        "value": [1, 2],
+    }
+)
+
 
 class GetLatest(Generic[T]):
     def __init__(self, default: T):