Skip to content

Commit

Permalink
Merge pull request #52 from ebmdatalab/bulk-upsert
Browse files Browse the repository at this point in the history
Switch to bulk upsert
  • Loading branch information
ghickman authored Nov 28, 2023
2 parents e5793a4 + 22459af commit a935950
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 26 deletions.
43 changes: 22 additions & 21 deletions metrics/timescaledb/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from sqlalchemy import create_engine, inspect, schema, text
from sqlalchemy.dialects.postgresql import insert

from ..tools.iter import batched


log = structlog.get_logger()

Expand All @@ -31,45 +33,44 @@ def ensure_table(engine, table):


class TimescaleDBWriter:
inserts = []

def __init__(self, table, engine=None):
if engine is None:
engine = create_engine(TIMESCALEDB_URL)

self.engine = engine
self.table = table
self.values = []

def __enter__(self):
ensure_table(self.engine, self.table)

return self

def __exit__(self, *args):
# get the primary key name from the given table
constraint = inspect(self.engine).get_pk_constraint(self.table.name)["name"]

with self.engine.begin() as connection:
for stmt in self.inserts:
connection.execute(stmt)
# batch our values (which are currently 5 item dicts) so we don't
# hit the 65535 params limit
for values in batched(self.values, 10_000):
stmt = insert(self.table).values(values)

# use the constraint for this table to drive upserting where the
# new value (excluded.value) is used to update the row
do_update_stmt = stmt.on_conflict_do_update(
constraint=constraint,
set_={"value": stmt.excluded.value},
)

connection.execute(do_update_stmt)
log.info("Inserted %s rows", len(values), table=self.table.name)

def write(self, date, value, **kwargs):
# convert date to a timestamp
# TODO: do we need to do any checking to make sure this is tz-aware and in
# UTC?
dt = datetime.combine(date, time())
value = {"time": dt, "value": value, **kwargs}

# get the primary key name from the given table
constraint = inspect(self.engine).get_pk_constraint(self.table.name)["name"]

# TODO: could we put all the rows at once in the values() call and
# then use EXCLUDED to reference the value in the set_?
insert_stmt = (
insert(self.table)
.values(time=dt, value=value, **kwargs)
.on_conflict_do_update(
constraint=constraint,
set_={"value": value},
)
)

self.inserts.append(insert_stmt)

log.debug(insert_stmt)
self.values.append(value)
14 changes: 14 additions & 0 deletions metrics/tools/iter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import itertools


def batched(iterable, n):
    """
    Backport of 3.12's itertools.batched
    https://docs.python.org/3/library/itertools.html#itertools.batched

    Yield successive tuples of up to ``n`` items taken from ``iterable``;
    the final tuple may be shorter than ``n``.

    batched('ABCDEFG', 3) --> ABC DEF G

    Raises ValueError (when iteration starts) if ``n`` is less than one,
    matching the stdlib recipe.  Without this check ``n == 0`` would
    silently yield nothing and drop every item.
    """
    if n < 1:
        raise ValueError("n must be at least one")

    it = iter(iterable)
    # islice returns an empty (falsy) tuple once the iterator is exhausted,
    # which ends the loop
    while batch := tuple(itertools.islice(it, n)):
        yield batch
10 changes: 9 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine import make_url
from sqlalchemy_utils import create_database, database_exists, drop_database

Expand All @@ -23,3 +23,11 @@ def engine():

# drop the database on test suite exit
drop_database(url)


@pytest.fixture
def has_table(engine):
    """
    Provide a callable that reports whether a table exists.

    The returned function takes a table name and checks for it via the
    ``engine`` fixture, creating a new inspector on every call.
    """

    def table_exists(table_name):
        inspector = inspect(engine)
        return inspector.has_table(table_name)

    return table_exists
8 changes: 4 additions & 4 deletions tests/metrics/timescaledb/test_writer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import date

import pytest
from sqlalchemy import TIMESTAMP, Column, Integer, Table, inspect, select, text
from sqlalchemy import TIMESTAMP, Column, Integer, Table, select, text
from sqlalchemy.engine import make_url

from metrics.timescaledb.tables import metadata
Expand Down Expand Up @@ -43,15 +43,15 @@ def table():
)


def test_timescaledbwriter(engine, table):
def test_timescaledbwriter(engine, has_table, table):
# check ensure_table is setting up the table
assert not inspect(engine).has_table(table.name)
assert not has_table(table.name)

with TimescaleDBWriter(table, engine) as writer:
for i in range(1, 4):
writer.write(date(2023, 11, i), i)

assert inspect(engine).has_table(table.name)
assert has_table(table.name)

# check there are timescaledb child tables
# https://stackoverflow.com/questions/1461722/how-to-find-child-tables-that-inherit-from-another-table-in-psql
Expand Down

0 comments on commit a935950

Please sign in to comment.