Skip to content

Commit

Permalink
Fix case sensitivity of table name by quoting it
Browse files Browse the repository at this point in the history
  • Loading branch information
artem.golovin committed Jun 29, 2024
1 parent 455f877 commit d2855ce
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 12 deletions.
4 changes: 3 additions & 1 deletion alembic_postgresql_enum/compare_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def compare_enums(
schema = default_schema

definitions = get_defined_enums(autogen_context.connection, schema)
declarations = get_declared_enums(autogen_context.metadata, schema, default_schema, autogen_context.connection)
declarations = get_declared_enums(
autogen_context.metadata, schema, default_schema, autogen_context.connection, upgrade_ops
)

create_new_enums(definitions, declarations.enum_values, schema, upgrade_ops)

Expand Down
52 changes: 49 additions & 3 deletions alembic_postgresql_enum/get_enum_data/declared_enums.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from collections import defaultdict
from enum import Enum
from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast
from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast, Dict

import sqlalchemy
from sqlalchemy import MetaData
from alembic.operations.ops import UpgradeOps, ModifyTableOps, AddColumnOp, CreateTableOp
from sqlalchemy import MetaData, Column
from sqlalchemy.dialects import postgresql

from alembic_postgresql_enum.sql_commands.column_default import get_column_default
Expand Down Expand Up @@ -45,11 +45,49 @@ def column_type_is_enum(column_type: Any) -> bool:
return False


def get_just_added_defaults(
upgrade_ops: Union[UpgradeOps, None], default_schema: str
) -> Dict[Tuple[str, str, str], str]:
"""Get all server defaults that will be added in current migration"""
if upgrade_ops is None:
return {}

new_server_defaults = {}

for operations_group in upgrade_ops.ops:
if isinstance(operations_group, ModifyTableOps):
for operation in operations_group.ops:
if isinstance(operation, AddColumnOp):
try:
if operation.column.server_default is None:
continue
new_server_defaults[
operation.schema or default_schema, operation.table_name, operation.column.name
] = operation.column.server_default.arg.text
except AttributeError:
pass

elif isinstance(operations_group, CreateTableOp):
for column in operations_group.columns:
if isinstance(column, Column):
try:
if column.server_default is None:
continue
new_server_defaults[column.table.schema or default_schema, column.table.name, column.name] = (
column.server_default.arg.text
)
except AttributeError:
pass

return new_server_defaults


def get_declared_enums(
metadata: Union[MetaData, List[MetaData]],
schema: str,
default_schema: str,
connection: "Connection",
upgrade_ops: Union[UpgradeOps, None] = None,
) -> DeclaredEnumValues:
"""
Return a dict mapping SQLAlchemy declared enumeration types to the set of their values
Expand All @@ -62,6 +100,8 @@ def get_declared_enums(
Default schema name, likely will be "public"
:param connection:
Database connection
:param upgrade_ops:
Upgrade operations in current migration
:returns DeclaredEnumValues:
enum_values: {
"my_enum": tuple(["a", "b", "c"]),
Expand All @@ -75,6 +115,10 @@ def get_declared_enums(
enum_name_to_values = dict()
enum_name_to_table_references: defaultdict[str, Set[TableReference]] = defaultdict(set)

just_added_defaults = get_just_added_defaults(upgrade_ops, default_schema)

# assert just_added_defaults == {}, just_added_defaults

if isinstance(metadata, list):
metadata_list = metadata
else:
Expand Down Expand Up @@ -103,6 +147,8 @@ def get_declared_enums(

table_schema = table.schema or default_schema
column_default = get_column_default(connection, table_schema, table.name, column.name)
if (table_schema, table.name, column.name) in just_added_defaults:
column_default = just_added_defaults[table_schema, table.name, column.name]
enum_name_to_table_references[column_type.name].add( # type: ignore[attr-defined]
TableReference(
table_schema=table_schema,
Expand Down
40 changes: 32 additions & 8 deletions alembic_postgresql_enum/sql_commands/column_default.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import TYPE_CHECKING, Union, List, Tuple

import sqlalchemy
Expand Down Expand Up @@ -54,15 +55,38 @@ def rename_default_if_required(
enum_name: str,
enum_values_to_rename: List[Tuple[str, str]],
) -> str:
is_array = default_value.endswith("[]")
if schema:
new_enum = f"{schema}.{enum_name}"
else:
new_enum = enum_name

if default_value.startswith("ARRAY["):
column_default_value = _replace_strings_in_quotes(default_value, enum_values_to_rename)
column_default_value = re.sub(r"::[.\w]+", f"::{new_enum}", column_default_value)
return column_default_value

if default_value.endswith("[]"):

# remove old type postfix
column_default_value = default_value[: default_value.find("::")]

column_default_value = _replace_strings_in_quotes(column_default_value, enum_values_to_rename)

return f"{column_default_value}::{new_enum}[]"

# remove old type postfix
column_default_value = default_value[: default_value.find("::")]

for old_value, new_value in enum_values_to_rename:
column_default_value = column_default_value.replace(f"'{old_value}'", f"'{new_value}'")
column_default_value = column_default_value.replace(f'"{old_value}"', f'"{new_value}"')
column_default_value = _replace_strings_in_quotes(column_default_value, enum_values_to_rename)

suffix = "[]" if is_array else ""
if schema:
return f"{column_default_value}::{schema}.{enum_name}{suffix}"
return f"{column_default_value}::{enum_name}{suffix}"
return f"{column_default_value}::{new_enum}"


def _replace_strings_in_quotes(
old_default: str,
enum_values_to_rename: List[Tuple[str, str]],
) -> str:
for old_value, new_value in enum_values_to_rename:
old_default = old_default.replace(f"'{old_value}'", f"'{new_value}'")
old_default = old_default.replace(f'"{old_value}"', f'"{new_value}"')
return old_default
33 changes: 33 additions & 0 deletions tests/sync_enum_values/test_rename_default_if_required.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,36 @@ def test_array_default_value_with_schema():
old_default_value = """'{}'::test.order_status_old[]"""

assert rename_default_if_required("test", old_default_value, "order_status", []) == """'{}'::test.order_status[]"""


def test_caps_array_default_value_without_schema():
old_default_value = """ARRAY['A'::my_old_enum, 'B'::my_old_enum]"""

assert (
rename_default_if_required("test", old_default_value, "my_enum", [])
== """ARRAY['A'::test.my_enum, 'B'::test.my_enum]"""
)


def test_caps_array_default_value_with_schema():
old_default_value = """ARRAY['A'::test.my_old_enum, 'B'::test.my_old_enum]"""

assert (
rename_default_if_required("test", old_default_value, "my_enum", [])
== """ARRAY['A'::test.my_enum, 'B'::test.my_enum]"""
)


def test_caps_array_another_default_value_without_schema():
old_default_value = """ARRAY['A'::my_old_enum, 'B'::my_old_enum]"""

assert (
rename_default_if_required("test", old_default_value, "my_enum", [])
== """ARRAY['A'::test.my_enum, 'B'::test.my_enum]"""
)


def test_caps_array_another_default_value_with_schema():
old_default_value = """ARRAY['A', 'B']::test.my_old_enum[]"""

assert rename_default_if_required("test", old_default_value, "my_enum", []) == """ARRAY['A', 'B']::test.my_enum[]"""
91 changes: 91 additions & 0 deletions tests/sync_enum_values/test_run_array_new_column.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from enum import Enum
from typing import TYPE_CHECKING

import sqlalchemy
from sqlalchemy import MetaData, Table, Column, insert
from sqlalchemy.dialects import postgresql

from tests.base.run_migration_test_abc import CompareAndRunTestCase

if TYPE_CHECKING:
from sqlalchemy import Connection


class OldEnum(Enum):
A = "a"
B = "b"


class NewEnum(Enum):
A = "a"
B = "b"
C = "c"


class TestNewArrayColumnColumn(CompareAndRunTestCase):
def get_database_schema(self) -> MetaData:
database_schema = MetaData()
Table("a", database_schema) # , Column("value", postgresql.ARRAY(postgresql.ENUM(OldEnum)))
Table(
"b",
database_schema,
Column(
"value",
postgresql.ARRAY(postgresql.ENUM(OldEnum, name="my_enum")),
server_default=sqlalchemy.text("ARRAY['A', 'B']::my_enum[]"),
),
)
return database_schema

def get_target_schema(self) -> MetaData:
target_schema = MetaData()
Table(
"a",
target_schema,
Column(
"value",
postgresql.ARRAY(postgresql.ENUM(NewEnum, name="my_enum")),
server_default=sqlalchemy.text("ARRAY['A', 'B']::my_enum[]"),
),
)
Table(
"b",
target_schema,
Column(
"value",
postgresql.ARRAY(postgresql.ENUM(NewEnum, name="my_enum")),
server_default=sqlalchemy.text("ARRAY['A', 'B']::my_enum[]"),
),
)
return target_schema

def insert_migration_data(self, connection: "Connection", database_schema: MetaData) -> None:
a_table = database_schema.tables["a"]
connection.execute(
insert(a_table).values(
[
{},
{},
]
)
)

def get_expected_upgrade(self) -> str:
return """
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('a', sa.Column('value', postgresql.ARRAY(postgresql.ENUM('A', 'B', 'C', name='my_enum', create_type=False)), server_default=sa.text("ARRAY['A', 'B']::my_enum[]"), nullable=True))
op.sync_enum_values('public', 'my_enum', ['A', 'B', 'C'],
[TableReference(table_schema='public', table_name='b', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A'::my_enum, 'B'::my_enum]"), TableReference(table_schema='public', table_name='a', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A', 'B']::my_enum[]")],
enum_values_to_rename=[])
# ### end Alembic commands ###
"""

def get_expected_downgrade(self) -> str:
return """
# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values('public', 'my_enum', ['A', 'B'],
[TableReference(table_schema='public', table_name='b', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A'::my_enum, 'B'::my_enum]"), TableReference(table_schema='public', table_name='a', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A', 'B']::my_enum[]")],
enum_values_to_rename=[])
op.drop_column('a', 'value')
# ### end Alembic commands ###
"""

0 comments on commit d2855ce

Please sign in to comment.