From 9a060fd626bd3bc2c1e6455248b11fef425c61b2 Mon Sep 17 00:00:00 2001 From: 0x26res Date: Mon, 27 Nov 2023 09:42:38 +0100 Subject: [PATCH] Add some arrow replay code (#49) --- CHANGELOG.md | 8 +++++ beavers/pyarrow_replay.py | 54 +++++++++++++++++++++++++++++++ tests/test_pyarrow_replay.py | 61 ++++++++++++++++++++++++++++++++++++ 3 files changed, 123 insertions(+) create mode 100644 beavers/pyarrow_replay.py create mode 100644 tests/test_pyarrow_replay.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b7098c..c2c4fb0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). +## [v0.4.0](https://github.com/tradewelltech/beavers/releases/tag/v0.4.0) - 2023-11-26 + +[Compare with v0.3.1](https://github.com/tradewelltech/beavers/compare/v0.3.1...v0.4.0) + +### Added + +- Add some arrow replay code ([d8026ec](https://github.com/tradewelltech/beavers/commit/d8026ecf744886b0bb7406814904adb3308ba0b9) by 0x26res). 
# NOTE: this span originates from a git patch whose line breaks were lost.
# The patch metadata that shared the line is kept verbatim as comments.
#
# CHANGELOG.md (trailing hunk lines):
#   +
#    ## [v0.3.1](https://github.com/tradewelltech/beavers/releases/tag/v0.3.1) - 2023-10-26
#
#    [Compare with v0.3.0](https://github.com/tradewelltech/beavers/compare/v0.3.0...v0.3.1)
#
# diff --git a/beavers/pyarrow_replay.py b/beavers/pyarrow_replay.py
# new file mode 100644
# index 0000000..5052bdf
# --- /dev/null
# +++ b/beavers/pyarrow_replay.py
# @@ -0,0 +1,54 @@
"""Replay helpers backed by in-memory ``pyarrow`` tables."""
import dataclasses
from typing import Callable

import pandas as pd
import pyarrow as pa

from beavers.engine import UTC_MAX
from beavers.replay import DataSink, DataSource


class ArrowTableDataSource(DataSource[pa.Table]):
    """Serve a ``pa.Table`` in successive slices driven by its timestamps.

    The timestamps returned by ``timestamp_extractor`` must be monotonic
    increasing; this is checked with an assertion at construction time.
    """

    def __init__(
        self, table: pa.Table, timestamp_extractor: Callable[[pa.Table], pa.Array]
    ):
        """Store ``table`` and extract its timestamp column as pandas data.

        Raises:
            AssertionError: if ``timestamp_extractor`` is not callable, or if
                the extracted timestamps are not monotonic increasing.
        """
        assert callable(timestamp_extractor)
        self._table = table
        # Zero-row table with the same schema, returned when nothing is due.
        self._empty_table = table.schema.empty_table()
        timestamps = timestamp_extractor(table).to_pandas(date_as_object=False)
        self._timestamp_column = timestamps
        assert (
            timestamps.is_monotonic_increasing
        ), "Timestamp column should be monotonic increasing"
        # Position of the first row that has not been handed out yet.
        self._index = 0

    def read_to(self, timestamp: pd.Timestamp) -> pa.Table:
        """Return the not-yet-read rows whose timestamp is <= ``timestamp``.

        Returns the empty table when no additional row is due.
        """
        # side="right" makes rows stamped exactly at ``timestamp`` inclusive.
        upper = self._timestamp_column.searchsorted(timestamp, side="right")
        if upper <= self._index:
            return self._empty_table
        lower = self._index
        self._index = upper
        return self._table.slice(lower, upper - lower)

    def get_next(self) -> pd.Timestamp:
        """Return the timestamp of the next unread row, or ``UTC_MAX`` if none."""
        if self._index < len(self._table):
            return self._timestamp_column.iloc[self._index]
        return UTC_MAX


@dataclasses.dataclass
class ArrowTableDataSink(DataSink[pa.Table]):
    """Accumulate tables and pass their concatenation to ``saver`` on close."""

    # Callback invoked by ``close`` with the concatenated table.
    saver: Callable[[pa.Table], None]
    # Tables received through ``append``, in arrival order.
    chunks: list[pa.Table] = dataclasses.field(default_factory=list)

    def append(self, timestamp: pd.Timestamp, data: pa.Table):
        """Buffer ``data``; ``timestamp`` is accepted but not used here."""
        self.chunks.append(data)

    def close(self):
        """Concatenate the buffered tables and save them; no-op when empty."""
        if not self.chunks:
            return
        combined = pa.concat_tables(self.chunks)
        self.saver(combined)


# diff --git a/tests/test_pyarrow_replay.py b/tests/test_pyarrow_replay.py
# new file mode 100644
# index 0000000..4a8f18b
# NOTE: this span originates from a git patch whose line breaks were lost.
# The patch metadata that shared the line is kept verbatim as comments.
#
# --- /dev/null
# +++ b/tests/test_pyarrow_replay.py
# @@ -0,0 +1,61 @@
"""Tests for ``beavers.pyarrow_replay``."""
from operator import itemgetter

import pandas as pd
import pyarrow as pa
import pyarrow.csv
import pytest

from beavers.engine import UTC_MAX
from beavers.pyarrow_replay import ArrowTableDataSink, ArrowTableDataSource

# Two rows, one per day, already ordered by timestamp.
TEST_TABLE = pa.table(
    {
        "timestamp": [
            pd.to_datetime("2023-01-01T00:00:00Z"),
            pd.to_datetime("2023-01-02T00:00:00Z"),
        ],
        "value": [1, 2],
    }
)


def test_arrow_table_data_source():
    """Rows come out one slice at a time as the read cursor advances."""
    first_day = pd.to_datetime("2023-01-01T00:00:00Z")
    second_day = pd.to_datetime("2023-01-02T00:00:00Z")
    source = ArrowTableDataSource(TEST_TABLE, itemgetter("timestamp"))

    assert source.get_next() == first_day
    assert source.read_to(first_day) == TEST_TABLE[:1]
    # Reading to the same timestamp again yields no rows.
    assert source.read_to(first_day) == TEST_TABLE[:0]
    assert source.get_next() == second_day
    assert source.read_to(second_day) == TEST_TABLE[1:]
    # Once exhausted, the source reports UTC_MAX and returns empty tables.
    assert source.get_next() == UTC_MAX
    assert source.read_to(UTC_MAX) == TEST_TABLE[:0]


def test_arrow_table_data_source_ooo():
    """Out-of-order timestamps are rejected at construction time."""
    out_of_order = pa.table(
        {
            "timestamp": [
                pd.to_datetime("2023-01-02T00:00:00Z"),
                pd.to_datetime("2023-01-01T00:00:00Z"),
            ],
            "value": [1, 2],
        }
    )
    with pytest.raises(
        AssertionError, match="Timestamp column should be monotonic increasing"
    ):
        ArrowTableDataSource(out_of_order, itemgetter("timestamp"))


def test_arrow_table_data_sink(tmpdir):
    """The saver runs on close only when at least one table was appended."""
    target = tmpdir / "file.csv"

    def write_csv(table):
        pyarrow.csv.write_csv(table, target)

    sink = ArrowTableDataSink(write_csv)

    # Closing with nothing appended must not call the saver.
    sink.close()
    assert not target.exists()

    sink.append(UTC_MAX, TEST_TABLE)
    sink.close()
    assert target.exists()