Skip to content

Commit

Permalink
Add some arrow replay code (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x26res authored Nov 27, 2023
1 parent a2e6b9a commit 9a060fd
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 0 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

<!-- insertion marker -->
## [v0.4.0](https://github.com/tradewelltech/beavers/releases/tag/v0.4.0) - 2023-11-26

<small>[Compare with v0.3.1](https://github.com/tradewelltech/beavers/compare/v0.3.1...v0.4.0)</small>

### Added

- Add some arrow replay code ([d8026ec](https://github.com/tradewelltech/beavers/commit/d8026ecf744886b0bb7406814904adb3308ba0b9) by 0x26res).

## [v0.3.1](https://github.com/tradewelltech/beavers/releases/tag/v0.3.1) - 2023-10-26

<small>[Compare with v0.3.0](https://github.com/tradewelltech/beavers/compare/v0.3.0...v0.3.1)</small>
Expand Down
54 changes: 54 additions & 0 deletions beavers/pyarrow_replay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import dataclasses
from typing import Callable

import pandas as pd
import pyarrow as pa

from beavers.engine import UTC_MAX
from beavers.replay import DataSink, DataSource


class ArrowTableDataSource(DataSource[pa.Table]):
    """Replay source that serves slices of a pyarrow Table in timestamp order.

    The timestamps extracted by ``timestamp_extractor`` must be sorted in
    increasing order; successive ``read_to`` calls hand out each row exactly
    once, advancing a cursor through the table.
    """

    def __init__(
        self, table: pa.Table, timestamp_extractor: Callable[[pa.Table], pa.Array]
    ):
        assert callable(timestamp_extractor)
        self._table = table
        # Returned whenever a read window contains no new rows.
        self._empty_table = table.schema.empty_table()
        # pandas Series view of the timestamps, used for searchsorted lookups.
        self._timestamp_column = timestamp_extractor(table).to_pandas(
            date_as_object=False
        )
        assert (
            self._timestamp_column.is_monotonic_increasing
        ), "Timestamp column should be monotonic increasing"
        # Index of the first row not yet handed out.
        self._index = 0

    def read_to(self, timestamp: pd.Timestamp) -> pa.Table:
        """Return every unread row whose timestamp is <= ``timestamp``."""
        stop = self._timestamp_column.searchsorted(timestamp, side="right")
        if stop <= self._index:
            return self._empty_table
        start, self._index = self._index, stop
        return self._table.slice(start, stop - start)

    def get_next(self) -> pd.Timestamp:
        """Return the timestamp of the next unread row, or ``UTC_MAX`` when exhausted."""
        if self._index < len(self._table):
            return self._timestamp_column.iloc[self._index]
        return UTC_MAX


@dataclasses.dataclass
class ArrowTableDataSink(DataSink[pa.Table]):
    """Replay sink that buffers table chunks and saves them all on close."""

    saver: Callable[[pa.Table], None]  # called once with the concatenated table
    chunks: list[pa.Table] = dataclasses.field(default_factory=list)

    def append(self, timestamp: pd.Timestamp, data: pa.Table):
        """Buffer ``data``; the timestamp is not needed for saving."""
        self.chunks.append(data)

    def close(self):
        """Concatenate and save all buffered chunks; no-op when nothing was appended."""
        if not self.chunks:
            return
        self.saver(pa.concat_tables(self.chunks))
61 changes: 61 additions & 0 deletions tests/test_pyarrow_replay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from operator import itemgetter

import pandas as pd
import pyarrow as pa
import pyarrow.csv
import pytest

from beavers.engine import UTC_MAX
from beavers.pyarrow_replay import ArrowTableDataSink, ArrowTableDataSource

# Two-row fixture with strictly increasing timestamps, shared by the tests below.
TEST_TABLE = pa.table(
    {
        "timestamp": [
            pd.to_datetime(value)
            for value in ("2023-01-01T00:00:00Z", "2023-01-02T00:00:00Z")
        ],
        "value": [1, 2],
    }
)


def test_arrow_table_data_source():
    """Rows come out in order, each exactly once, then the source reports UTC_MAX."""
    source = ArrowTableDataSource(TEST_TABLE, itemgetter("timestamp"))
    day_1 = pd.to_datetime("2023-01-01T00:00:00Z")
    day_2 = pd.to_datetime("2023-01-02T00:00:00Z")

    assert source.get_next() == day_1
    assert source.read_to(day_1) == TEST_TABLE[:1]
    # Reading to the same timestamp again yields an empty table:
    assert source.read_to(day_1) == TEST_TABLE[:0]
    assert source.get_next() == day_2
    assert source.read_to(day_2) == TEST_TABLE[1:]
    assert source.get_next() == UTC_MAX
    assert source.read_to(UTC_MAX) == TEST_TABLE[:0]


def test_arrow_table_data_source_ooo():
    """A table whose timestamps are out of order is rejected at construction."""
    out_of_order = pa.table(
        {
            "timestamp": [
                pd.to_datetime("2023-01-02T00:00:00Z"),
                pd.to_datetime("2023-01-01T00:00:00Z"),
            ],
            "value": [1, 2],
        }
    )
    with pytest.raises(
        AssertionError, match="Timestamp column should be monotonic increasing"
    ):
        ArrowTableDataSource(out_of_order, itemgetter("timestamp"))


def test_arrow_table_data_sink(tmpdir):
    """close() writes nothing when empty, and writes the appended data otherwise."""
    destination = tmpdir / "file.csv"
    sink = ArrowTableDataSink(
        lambda table: pyarrow.csv.write_csv(table, destination)
    )

    # Nothing appended yet, so closing must not create the file:
    sink.close()
    assert not destination.exists()

    sink.append(UTC_MAX, TEST_TABLE)
    sink.close()
    assert destination.exists()

0 comments on commit 9a060fd

Please sign in to comment.