diff --git a/beavers/kafka.py b/beavers/kafka.py
index 4dc7f10..e967554 100644
--- a/beavers/kafka.py
+++ b/beavers/kafka.py
@@ -4,7 +4,7 @@
 import logging
 import time
 from enum import Enum
-from typing import Any, AnyStr, Generic, Optional, Protocol, TypeVar
+from typing import Any, AnyStr, Generic, Optional, Protocol, Sequence, TypeVar
 
 import confluent_kafka
 import confluent_kafka.admin
@@ -22,7 +22,7 @@
 class KafkaMessageDeserializer(Protocol[T]):
     """Interface for converting incoming kafka messages to custom data."""
 
-    def __call__(self, messages: list[confluent_kafka.Message]) -> T:
+    def __call__(self, messages: Sequence[confluent_kafka.Message]) -> T:
         """Convert batch of messages to data."""
 
 
@@ -38,7 +38,7 @@ class KafkaProducerMessage:
 class KafkaMessageSerializer(Protocol[T]):
     """Interface for converting custom data to outgoing kafka messages."""
 
-    def __call__(self, value: T) -> list[KafkaProducerMessage]:
+    def __call__(self, value: T) -> Sequence[KafkaProducerMessage]:
         """Convert batch of custom data to `KafkaProducerMessage`."""
 
 
@@ -384,29 +384,14 @@ def _update_partition_info(self, new_messages: list[confluent_kafka.Message]):
             )
 
 
-@dataclasses.dataclass(frozen=True)
-class _RuntimeSinkTopic:
-    nodes: list[Node]
-    serializer: KafkaMessageSerializer
-
-    def flush(self, cycle_id: int, producer_manger: _ProducerManager):
-        for node in self.nodes:
-            if node.get_cycle_id() == cycle_id:
-                node_value = node.get_sink_value()
-                # TODO: capture serialization time in metrics
-                messages = self.serializer(node_value)
-                for message in messages:
-                    producer_manger.produce_one(
-                        message.topic, message.key, message.value
-                    )
-
-
 @dataclasses.dataclass
 class ExecutionMetrics:
     """Metrics for the execution of a dag."""
 
     serialization_ns: int = 0
     serialization_count: int = 0
+    deserialization_ns: int = 0
+    deserialization_count: int = 0
     execution_ns: int = 0
     execution_count: int = 0
 
@@ -419,6 +404,15 @@ def measure_serialization_time(self):
             self.serialization_ns += time.time_ns() - before
             self.serialization_count += 1
 
+    @contextlib.contextmanager
+    def measure_deserialization_time(self):
+        before = time.time_ns()
+        try:
+            yield
+        finally:
+            self.deserialization_ns += time.time_ns() - before
+            self.deserialization_count += 1
+
     @contextlib.contextmanager
     def measure_execution_time(self):
         before = time.time_ns()
@@ -429,6 +423,20 @@ def measure_execution_time(self):
             self.execution_count += 1
 
 
+@dataclasses.dataclass(frozen=True)
+class _RuntimeSinkTopic:
+    nodes: list[Node]
+    serializer: KafkaMessageSerializer
+
+    def serialize(self, cycle_id: int) -> list[KafkaProducerMessage]:
+        messages = []
+        for node in self.nodes:
+            if node.get_cycle_id() == cycle_id:
+                node_value = node.get_sink_value()
+                messages.extend(self.serializer(node_value))
+        return messages
+
+
 class KafkaDriver:
     """Control the execution of a dag, using data from kafka."""
 
@@ -519,13 +527,19 @@ def _process_message(self, message: confluent_kafka.Message):
         self._source_topics[message.topic()].append(message)
 
     def _produce_records(self, cycle_id: int):
-        for sink_topic in self._sink_topics:
-            sink_topic.flush(cycle_id, self._producer_manager)
+        messages = []
+        with self._metrics.measure_serialization_time():
+            for sink_topic in self._sink_topics:
+                messages.extend(sink_topic.serialize(cycle_id))
+        for message in messages:
+            self._producer_manager.produce_one(
+                message.topic, message.key, message.value
+            )
 
     def _run_cycle(self, messages: list[confluent_kafka.Message]) -> bool:
         has_messages = False
-        with self._metrics.measure_serialization_time():
-            self._process_messages(messages)
+        self._process_messages(messages)
+        with self._metrics.measure_deserialization_time():
             for handler in self._source_topics.values():
                 has_messages = handler.flush() or has_messages
         cycle_time = (
@@ -652,7 +666,6 @@ def _get_previous_start_of_day(
     if (local_now - local_now.normalize()) > start_of_day_time:
         return (local_now.normalize() + start_of_day_time).tz_convert("UTC")
     else:
-        # TODO: consider adding calendar?
         return (
             local_now.normalize() - pd.to_timedelta("1d") + start_of_day_time
         ).tz_convert("UTC")
diff --git a/tests/test_kafka.py b/tests/test_kafka.py
index f0f6942..6ab0ad4 100644
--- a/tests/test_kafka.py
+++ b/tests/test_kafka.py
@@ -572,8 +572,10 @@ def test_kafka_driver_word_count(log_helper: LogHelper):
     assert len(log_helper.flush()) == 1
 
     metrics = kafka_driver.flush_metrics()
+    assert metrics.deserialization_ns > 0
+    assert metrics.deserialization_count == 6
     assert metrics.serialization_ns > 0
-    assert metrics.serialization_count == 6
+    assert metrics.serialization_count == 3
     assert metrics.execution_ns > 0
     assert metrics.execution_count == 3
 
@@ -1308,19 +1310,14 @@ def test_runtime_sink_topic():
     sink = dag.sink("sink", node)
     runtime_sink_topic = _RuntimeSinkTopic([sink], WorldCountSerializer("topic-1"))
-    producer_manager = MockProducerManager()
 
     dag.execute()
-    runtime_sink_topic.flush(dag.get_cycle_id(), producer_manager)
-    assert producer_manager.messages == []
+    assert runtime_sink_topic.serialize(dag.get_cycle_id()) == []
 
     node.set_stream({"foo": "bar"})
     dag.execute()
-    runtime_sink_topic.flush(dag.get_cycle_id(), producer_manager)
-    assert producer_manager.messages == [
+    assert runtime_sink_topic.serialize(dag.get_cycle_id()) == [
         KafkaProducerMessage(topic="topic-1", key=b"foo", value=b"bar")
     ]
-    producer_manager.messages.clear()
 
     dag.execute()
-    runtime_sink_topic.flush(dag.get_cycle_id(), producer_manager)
-    assert producer_manager.messages == []
+    assert runtime_sink_topic.serialize(dag.get_cycle_id()) == []
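Illustrative sketch, not part of the diff: with KafkaMessageSerializer relaxed from list to Sequence, a serializer may now return any sequence of KafkaProducerMessage, e.g. a tuple. The class name WordCountSerializer and the dict-of-counts payload below are assumptions for the example only, not identifiers from this change.

```python
from typing import Sequence

from beavers.kafka import KafkaProducerMessage


class WordCountSerializer:
    """Hypothetical serializer: one message per (word, count) pair."""

    def __init__(self, topic: str):
        self._topic = topic

    def __call__(self, value: dict[str, int]) -> Sequence[KafkaProducerMessage]:
        # Returning a tuple satisfies the Sequence-based protocol introduced
        # by this change; previously the annotation required a list.
        return tuple(
            KafkaProducerMessage(
                topic=self._topic,
                key=word.encode("utf-8"),
                value=str(count).encode("utf-8"),
            )
            for word, count in value.items()
        )
```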