Skip to content

Commit

Permalink
Merge pull request #235 from ebmdatalab/benbc/connection-refactor
Browse files Browse the repository at this point in the history
Reuse connections when resetting tables
  • Loading branch information
benbc authored Sep 24, 2024
2 parents 8f972dc + 6d5404d commit efe48b1
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 54 deletions.
71 changes: 35 additions & 36 deletions metrics/timescaledb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@


def reset_table(table, batch_size=None):
_drop_table(table, batch_size)
_ensure_table(table)
log.info("Reset table", table=table.name)
with _get_engine().begin() as connection:
_drop_table(connection, table, batch_size)
_ensure_table(connection, table)
log.info("Reset table", table=table.name)


def write(table, rows):
Expand All @@ -28,15 +29,15 @@ def write(table, rows):


def upsert(table, rows):
_ensure_table(table)
batch_size = _batch_size(table)
non_pk_columns = set(table.columns) - set(table.primary_key.columns)
with _get_engine().begin() as connection:
_ensure_table(connection, table)
batch_size = _batch_size(table)
non_pk_columns = set(table.columns) - set(table.primary_key.columns)

# use the primary key constraint to match rows to be updated,
# using default constraint name if not otherwise specified
constraint = table.primary_key.name or table.name + "_pkey"
# use the primary key constraint to match rows to be updated,
# using default constraint name if not otherwise specified
constraint = table.primary_key.name or table.name + "_pkey"

with _get_engine().begin() as connection:
for values in batched(rows, batch_size):
# Construct a PostgreSQL "INSERT..ON CONFLICT" style upsert statement
# https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#insert-on-conflict-upsert
Expand Down Expand Up @@ -68,30 +69,29 @@ def _batch_size(table):
return max_params // len(table.columns)


def _drop_table(table, batch_size):
with _get_engine().begin() as connection:
log.debug("Removing table: %s", table.name)
def _drop_table(connection, table, batch_size):
log.debug("Removing table: %s", table.name)

if not _has_table(connection, table):
return
if not _has_table(connection, table):
return

if _is_hypertable(table):
# We have limited shared memory in our hosted database, so we can't DROP or
# TRUNCATE our hypertables. Instead for each "raw" table we need to:
# * empty the raw rows (from the named table) in batches
# * drop the sharded "child" tables in batches
# * drop the now empty raw table
while _has_rows(connection, table):
_delete_rows(connection, table, batch_size)
if _is_hypertable(table):
# We have limited shared memory in our hosted database, so we can't DROP or
# TRUNCATE our hypertables. Instead for each "raw" table we need to:
# * empty the raw rows (from the named table) in batches
# * drop the sharded "child" tables in batches
# * drop the now empty raw table
while _has_rows(connection, table):
_delete_rows(connection, table, batch_size)

log.debug("Removed all raw rows", table=table.name)
log.debug("Removed all raw rows", table=table.name)

_drop_child_tables(connection, table)
log.debug("Removed all child tables", table=table.name)
_drop_child_tables(connection, table)
log.debug("Removed all child tables", table=table.name)

connection.execute(text(f"DROP TABLE {table.name}"))
connection.execute(text(f"DROP TABLE {table.name}"))

log.debug("Removed raw table", table=table.name)
log.debug("Removed raw table", table=table.name)


def _has_table(connection, table):
Expand Down Expand Up @@ -142,16 +142,15 @@ def _drop_child_tables(connection, table):
connection.execute(text(f"DROP TABLE IF EXISTS {tables}"))


def _ensure_table(table):
with _get_engine().begin() as connection:
connection.execute(schema.CreateTable(table, if_not_exists=True))
def _ensure_table(connection, table):
connection.execute(schema.CreateTable(table, if_not_exists=True))

if _is_hypertable(table):
connection.execute(
text(
f"SELECT create_hypertable('{table.name}', 'time', if_not_exists => TRUE);"
)
if _is_hypertable(table):
connection.execute(
text(
f"SELECT create_hypertable('{table.name}', 'time', if_not_exists => TRUE);"
)
)


@functools.cache
Expand Down
32 changes: 14 additions & 18 deletions tests/metrics/timescaledb/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ def get_rows(engine, table):
return connection.execute(select(table)).all()


def assert_is_hypertable(engine, table):
def assert_is_hypertable(connection, engine, table):
# check there are timescaledb child tables
# https://stackoverflow.com/questions/1461722/how-to-find-child-tables-that-inherit-from-another-table-in-psql

sql = """
SELECT
count(*)
Expand All @@ -46,8 +49,7 @@ def assert_is_hypertable(engine, table):
trigger_name = 'ts_insert_blocker';
"""

with engine.connect() as connection:
result = connection.execute(text(sql), {"table_name": table.name}).fetchone()
result = connection.execute(text(sql), {"table_name": table.name}).fetchone()

# We should have one trigger called ts_insert_blocker for a hypertable.
assert result[0] == 1, result
Expand Down Expand Up @@ -75,25 +77,16 @@ def hypertable(request):
def test_ensure_table(engine, table):
with engine.begin() as connection:
assert not db._has_table(connection, table)

db._ensure_table(table)

with engine.begin() as connection:
db._ensure_table(connection, table)
assert db._has_table(connection, table)


def test_ensure_hypertable(engine, hypertable):
with engine.begin() as connection:
assert not db._has_table(connection, hypertable)

db._ensure_table(hypertable)

with engine.begin() as connection:
db._ensure_table(connection, hypertable)
assert db._has_table(connection, hypertable)

# check there are timescaledb child tables
# https://stackoverflow.com/questions/1461722/how-to-find-child-tables-that-inherit-from-another-table-in-psql
assert_is_hypertable(engine, hypertable)
assert_is_hypertable(connection, engine, hypertable)


def test_get_url(monkeypatch):
Expand All @@ -114,7 +107,8 @@ def test_get_url_with_prefix(monkeypatch):


def test_reset_table(engine, table):
db._ensure_table(table)
with engine.begin() as connection:
db._ensure_table(connection, table)

# put enough rows in the db to make sure we exercise the batch removal of rows
batch_size = 5
Expand All @@ -126,7 +120,8 @@ def test_reset_table(engine, table):


def test_reset_hypertable(engine, hypertable):
db._ensure_table(hypertable)
with engine.begin() as connection:
db._ensure_table(connection, hypertable)

# put enough rows in the db to make sure we exercise the batch removal of rows
batch_size = 5
Expand Down Expand Up @@ -156,7 +151,8 @@ def check_reset(batch_size, engine, rows, table):

def test_write(engine, table):
# set up a table to write to
db._ensure_table(table)
with engine.begin() as connection:
db._ensure_table(connection, table)

rows = [{"value": "write" + str(i)} for i in range(1, 4)]
db.write(table, rows)
Expand Down

0 comments on commit efe48b1

Please sign in to comment.