Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: make DDL overridable for column ADD, ALTER, and RENAME operations #1114

Merged
merged 11 commits into from
Nov 9, 2022
118 changes: 92 additions & 26 deletions singer_sdk/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,21 +630,10 @@ def _create_empty_column(
if not self.allow_column_add:
raise NotImplementedError("Adding columns is not supported.")

create_column_clause = sqlalchemy.schema.CreateColumn(
sqlalchemy.Column(
column_name,
sql_type,
)
)
self.connection.execute(
sqlalchemy.DDL(
"ALTER TABLE %(table)s ADD COLUMN %(create_column)s",
{
"table": full_table_name,
"create_column": create_column_clause,
},
)
column_add_ddl = self.get_column_add_ddl(
table_name=full_table_name, column_name=column_name, column_type=sql_type
)
self.connection.execute(column_add_ddl)

def prepare_schema(self, schema_name: str) -> None:
"""Create the target database schema.
Expand Down Expand Up @@ -729,10 +718,10 @@ def rename_column(self, full_table_name: str, old_name: str, new_name: str) -> N
if not self.allow_column_rename:
raise NotImplementedError("Renaming columns is not supported.")

self.connection.execute(
f"ALTER TABLE {full_table_name} "
f'RENAME COLUMN "{old_name}" to "{new_name}"'
column_rename_ddl = self.get_column_rename_ddl(
table_name=full_table_name, column_name=old_name, new_column_name=new_name
)
self.connection.execute(column_rename_ddl)

def merge_sql_types(
self, sql_types: list[sqlalchemy.types.TypeEngine]
Expand Down Expand Up @@ -871,6 +860,87 @@ def _get_column_type(

return cast(sqlalchemy.types.TypeEngine, column.type)

@staticmethod
def get_column_add_ddl(
table_name: str, column_name: str, column_type: sqlalchemy.types.TypeEngine
) -> sqlalchemy.DDL:
"""Get the create column DDL statement.

Override this if your database uses a different syntax for creating columns.

Args:
table_name: Fully qualified table name of column to alter.
column_name: Column name to create.
column_type: New column sqlalchemy type.

Returns:
A sqlalchemy DDL instance.
"""
create_column_clause = sqlalchemy.schema.CreateColumn(
sqlalchemy.Column(
column_name,
column_type,
)
)
return sqlalchemy.DDL(
"ALTER TABLE %(table_name)s ADD COLUMN %(create_column_clause)s",
{
"table_name": table_name,
"create_column_clause": create_column_clause,
},
)

@staticmethod
def get_column_rename_ddl(
table_name: str, column_name: str, new_column_name: str
) -> sqlalchemy.DDL:
"""Get the create column DDL statement.

Override this if your database uses a different syntax for renaming columns.

Args:
table_name: Fully qualified table name of column to alter.
column_name: Existing column name.
new_column_name: New column name.

Returns:
A sqlalchemy DDL instance.
"""
return sqlalchemy.DDL(
"ALTER TABLE %(table_name)s "
"RENAME COLUMN %(column_name)s to %(new_column_name)s",
{
"table_name": table_name,
"column_name": column_name,
"new_column_name": new_column_name,
},
)

@staticmethod
def get_column_alter_ddl(
table_name: str, column_name: str, column_type: sqlalchemy.types.TypeEngine
) -> sqlalchemy.DDL:
"""Get the alter column DDL statement.

Override this if your database uses a different syntax for altering columns.

Args:
table_name: Fully qualified table name of column to alter.
column_name: Column name to alter.
column_type: New column type string.

Returns:
A sqlalchemy DDL instance.
"""
return sqlalchemy.DDL(
"ALTER TABLE %(table_name)s ALTER COLUMN %(column_name)s (%(column_type)s)",
{
"table_name": table_name,
"column_name": column_name,
"column_type": column_type,
},
)

def _adapt_column_type(
self,
full_table_name: str,
Expand Down Expand Up @@ -912,13 +982,9 @@ def _adapt_column_type(
f"from '{current_type}' to '{compatible_sql_type}'."
)

self.connection.execute(
sqlalchemy.DDL(
"ALTER TABLE %(table)s ALTER COLUMN %(col_name)s (%(col_type)s)",
{
"table": full_table_name,
"col_name": column_name,
"col_type": compatible_sql_type,
},
)
alter_column_ddl = self.get_column_alter_ddl(
table_name=full_table_name,
column_name=column_name,
column_type=compatible_sql_type,
)
self.connection.execute(alter_column_ddl)
93 changes: 93 additions & 0 deletions tests/core/test_connector_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import pytest
import sqlalchemy
from sqlalchemy.dialects import sqlite

from singer_sdk.connectors import SQLConnector


def stringify(in_dict):
return {k: str(v) for k, v in in_dict.items()}


class TestConnectorSQL:
"""Test the SQLConnector class."""

@pytest.fixture()
def connector(self):
return SQLConnector()

@pytest.mark.parametrize(
"method_name,kwargs,context,unrendered_statement,rendered_statement",
[
(
"get_column_add_ddl",
{
"table_name": "full.table.name",
"column_name": "column_name",
"column_type": sqlalchemy.types.Text(),
},
{
"table_name": "full.table.name",
"create_column_clause": sqlalchemy.schema.CreateColumn(
sqlalchemy.Column(
"column_name",
sqlalchemy.types.Text(),
)
),
},
"ALTER TABLE %(table_name)s ADD COLUMN %(create_column_clause)s",
"ALTER TABLE full.table.name ADD COLUMN column_name TEXT",
),
(
"get_column_rename_ddl",
{
"table_name": "full.table.name",
"column_name": "old_name",
"new_column_name": "new_name",
},
{
"table_name": "full.table.name",
"column_name": "old_name",
"new_column_name": "new_name",
},
"ALTER TABLE %(table_name)s RENAME COLUMN %(column_name)s to %(new_column_name)s",
"ALTER TABLE full.table.name RENAME COLUMN old_name to new_name",
),
(
"get_column_alter_ddl",
{
"table_name": "full.table.name",
"column_name": "column_name",
"column_type": sqlalchemy.types.String(),
},
{
"table_name": "full.table.name",
"column_name": "column_name",
"column_type": sqlalchemy.types.String(),
},
"ALTER TABLE %(table_name)s ALTER COLUMN %(column_name)s (%(column_type)s)",
"ALTER TABLE full.table.name ALTER COLUMN column_name (VARCHAR)",
),
],
)
def test_get_column_ddl(
self,
connector,
method_name,
kwargs,
context,
unrendered_statement,
rendered_statement,
):
method = getattr(connector, method_name)
column_ddl = method(**kwargs)

assert stringify(column_ddl.context) == stringify(context)
assert column_ddl.statement == unrendered_statement

statement = str(
column_ddl.compile(
dialect=sqlite.dialect(), compile_kwargs={"literal_binds": True}
)
)
assert statement == rendered_statement