diff --git a/metrics/timescaledb/db.py b/metrics/timescaledb/db.py index 7b24aa3..ceb59e1 100644 --- a/metrics/timescaledb/db.py +++ b/metrics/timescaledb/db.py @@ -15,6 +15,7 @@ def reset_table(table, batch_size=None): with _get_engine().begin() as connection: _drop_table(connection, table, batch_size) + with _get_engine().begin() as connection: _ensure_table(connection, table) log.info("Reset table", table=table.name) @@ -143,7 +144,41 @@ def _drop_child_tables(connection, table): def _ensure_table(connection, table): - connection.execute(schema.CreateTable(table, if_not_exists=True)) + def alter_table(table, engine): + return f"ALTER TABLE {table.name}" + + def add_column(table, col, engine): + col_name = col.compile(dialect=engine.dialect) + col_type = col.type.compile(dialect=engine.dialect) + + if col.nullable: + nullability = "NULL" + else: + if not col.default: + raise ValueError("Adding a nullable column requires a default value") + type_processor = col.type.literal_processor(dialect=engine.dialect) + processed_default = type_processor(col.default.arg) + nullability = f"NOT NULL DEFAULT {processed_default}" + return f"{alter_table(table, engine)} ADD {col_name} {col_type} {nullability}" + + def drop_column(table, col, engine): + return f"{alter_table(table, engine)} DROP COLUMN {col.name}" + + metadata = MetaData() + engine = connection.engine + metadata.reflect(engine) + if table.name not in metadata.tables: + connection.execute(schema.CreateTable(table)) + else: + db_table = metadata.tables[table.name] + for column in [ + c for c in table.columns if c.name not in [d.name for d in db_table.columns] + ]: + connection.execute(text(add_column(table, column, engine))) + for column in [ + d for d in db_table.columns if d.name not in [c.name for c in table.columns] + ]: + connection.execute(text(drop_column(table, column, engine))) if _is_hypertable(table): connection.execute( diff --git a/tests/metrics/timescaledb/test_db.py b/tests/metrics/timescaledb/test_db.py index ee2fca8..be91688 100644 --- a/tests/metrics/timescaledb/test_db.py +++ b/tests/metrics/timescaledb/test_db.py @@ -1,7 +1,17 @@ import datetime import pytest -from sqlalchemy import TIMESTAMP, Column, Table, Text, create_engine, select, text +from sqlalchemy import ( + TIMESTAMP, + Column, + Integer, + MetaData, + Table, + Text, + create_engine, + select, + text, +) from sqlalchemy_utils import create_database, database_exists, drop_database from metrics.timescaledb import db, tables @@ -74,13 +84,43 @@ def hypertable(request): ) -def test_ensure_table(engine, table): +def test_ensure_new_table(engine, table): with engine.begin() as connection: assert not db._has_table(connection, table) db._ensure_table(connection, table) assert db._has_table(connection, table) +def test_ensure_existing_table_no_changes(engine, table): + with engine.begin() as connection: + assert not db._has_table(connection, table) + db._ensure_table(connection, table) + assert db._has_table(connection, table) + + +def test_ensure_existing_table_with_changes(engine, table): + with engine.begin() as connection: + db._ensure_table(connection, table) + with engine.begin() as connection: + test_column_nullable = Column("test", Text, nullable=True) + test_column_not_nullable = Column( + "testnotnull", Integer, nullable=False, default=123 + ) + table._columns.add(test_column_nullable) + table._columns.add(test_column_not_nullable) + db._ensure_table(connection, table) + with engine.begin() as connection: + metadata = MetaData() + metadata.reflect(engine) + assert test_column_nullable.name in metadata.tables[table.name].columns + table._columns.remove(test_column_nullable) + db._ensure_table(connection, table) + with engine.begin() as connection: + metadata = MetaData() + metadata.reflect(engine) + assert test_column_nullable.name not in metadata.tables[table.name].columns + + def test_ensure_hypertable(engine, hypertable): with engine.begin() as connection: assert not db._has_table(connection, hypertable)