diff --git a/metrics/timescaledb/writer.py b/metrics/timescaledb/writer.py index 2b38e08d..085459b2 100644 --- a/metrics/timescaledb/writer.py +++ b/metrics/timescaledb/writer.py @@ -5,6 +5,8 @@ from sqlalchemy import create_engine, inspect, schema, text from sqlalchemy.dialects.postgresql import insert +from ..tools.iter import batched + log = structlog.get_logger() @@ -31,14 +33,13 @@ def ensure_table(engine, table): class TimescaleDBWriter: - inserts = [] - def __init__(self, table, engine=None): if engine is None: engine = create_engine(TIMESCALEDB_URL) self.engine = engine self.table = table + self.values = [] def __enter__(self): ensure_table(self.engine, self.table) @@ -46,30 +47,30 @@ def __enter__(self): return self def __exit__(self, *args): + # get the primary key name from the given table + constraint = inspect(self.engine).get_pk_constraint(self.table.name)["name"] + with self.engine.begin() as connection: - for stmt in self.inserts: - connection.execute(stmt) + # batch our values (which are currently 5 item dicts) so we don't + # hit the 65535 params limit + for values in batched(self.values, 10_000): + stmt = insert(self.table).values(values) + + # use the constraint for this table to drive upserting where the + # new value (excluded.value) is used to update the row + do_update_stmt = stmt.on_conflict_do_update( + constraint=constraint, + set_={"value": stmt.excluded.value}, + ) + + connection.execute(do_update_stmt) + log.info("Inserted %s rows", len(values), table=self.table.name) def write(self, date, value, **kwargs): # convert date to a timestamp # TODO: do we need to do any checking to make sure this is tz-aware and in # UTC? dt = datetime.combine(date, time()) + value = {"time": dt, "value": value, **kwargs} - # get the primary key name from the given table - constraint = inspect(self.engine).get_pk_constraint(self.table.name)["name"] - - # TODO: could we put do all the rows at once in the values() call and - # then use EXCLUDED to reference the value in the set_? - insert_stmt = ( - insert(self.table) - .values(time=dt, value=value, **kwargs) - .on_conflict_do_update( - constraint=constraint, - set_={"value": value}, - ) - ) - - self.inserts.append(insert_stmt) - - log.debug(insert_stmt) + self.values.append(value) diff --git a/metrics/tools/iter.py b/metrics/tools/iter.py new file mode 100644 index 00000000..bafd23b1 --- /dev/null +++ b/metrics/tools/iter.py @@ -0,0 +1,14 @@ +import itertools + + +def batched(iterable, n): + """ + Backport of 3.12's itertools.batched + + https://docs.python.org/3/library/itertools.html#itertools.batched + + batched('ABCDEFG', 3) --> ABC DEF G + """ + it = iter(iterable) + while batch := tuple(itertools.islice(it, n)): + yield batch diff --git a/tests/conftest.py b/tests/conftest.py index 547d08d7..c0c5896f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ import pytest -from sqlalchemy import create_engine +from sqlalchemy import create_engine, inspect from sqlalchemy.engine import make_url from sqlalchemy_utils import create_database, database_exists, drop_database @@ -23,3 +23,11 @@ def engine(): # drop the database on test suite exit drop_database(url) + + +@pytest.fixture +def has_table(engine): + def checker(table_name): + return inspect(engine).has_table(table_name) + + return checker diff --git a/tests/metrics/timescaledb/test_writer.py b/tests/metrics/timescaledb/test_writer.py index 6a4e5903..e344c5ba 100644 --- a/tests/metrics/timescaledb/test_writer.py +++ b/tests/metrics/timescaledb/test_writer.py @@ -1,7 +1,7 @@ from datetime import date import pytest -from sqlalchemy import TIMESTAMP, Column, Integer, Table, inspect, select, text +from sqlalchemy import TIMESTAMP, Column, Integer, Table, select, text from sqlalchemy.engine import make_url from metrics.timescaledb.tables import metadata @@ -43,15 +43,15 @@ def table(): ) -def test_timescaledbwriter(engine, table): +def test_timescaledbwriter(engine, has_table, table): # check ensure_table is setting up the table - assert not inspect(engine).has_table(table.name) + assert not has_table(table.name) with TimescaleDBWriter(table, engine) as writer: for i in range(1, 4): writer.write(date(2023, 11, i), i) - assert inspect(engine).has_table(table.name) + assert has_table(table.name) # check there are timescaledb child tables # https://stackoverflow.com/questions/1461722/how-to-find-child-tables-that-inherit-from-another-table-in-psql