diff --git a/alembic_postgresql_enum/compare_dispatch.py b/alembic_postgresql_enum/compare_dispatch.py index 1f40172..53ccba4 100644 --- a/alembic_postgresql_enum/compare_dispatch.py +++ b/alembic_postgresql_enum/compare_dispatch.py @@ -61,7 +61,9 @@ def compare_enums( schema = default_schema definitions = get_defined_enums(autogen_context.connection, schema) - declarations = get_declared_enums(autogen_context.metadata, schema, default_schema, autogen_context.connection) + declarations = get_declared_enums( + autogen_context.metadata, schema, default_schema, autogen_context.connection, upgrade_ops + ) create_new_enums(definitions, declarations.enum_values, schema, upgrade_ops) diff --git a/alembic_postgresql_enum/get_enum_data/declared_enums.py b/alembic_postgresql_enum/get_enum_data/declared_enums.py index 75706b8..77e2d61 100644 --- a/alembic_postgresql_enum/get_enum_data/declared_enums.py +++ b/alembic_postgresql_enum/get_enum_data/declared_enums.py @@ -1,9 +1,9 @@ from collections import defaultdict -from enum import Enum -from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast +from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast, Dict import sqlalchemy -from sqlalchemy import MetaData +from alembic.operations.ops import UpgradeOps, ModifyTableOps, AddColumnOp, CreateTableOp +from sqlalchemy import MetaData, Column from sqlalchemy.dialects import postgresql from alembic_postgresql_enum.sql_commands.column_default import get_column_default @@ -45,11 +45,49 @@ def column_type_is_enum(column_type: Any) -> bool: return False +def get_just_added_defaults( + upgrade_ops: Union[UpgradeOps, None], default_schema: str +) -> Dict[Tuple[str, str, str], str]: + """Get all server defaults that will be added in current migration""" + if upgrade_ops is None: + return {} + + new_server_defaults = {} + + for operations_group in upgrade_ops.ops: + if isinstance(operations_group, ModifyTableOps): + for operation in operations_group.ops: + if isinstance(operation, AddColumnOp): + try: + if operation.column.server_default is None: + continue + new_server_defaults[ + operation.schema or default_schema, operation.table_name, operation.column.name + ] = operation.column.server_default.arg.text + except AttributeError: + pass + + elif isinstance(operations_group, CreateTableOp): + for column in operations_group.columns: + if isinstance(column, Column): + try: + if column.server_default is None: + continue + new_server_defaults[column.table.schema or default_schema, column.table.name, column.name] = ( + column.server_default.arg.text + ) + except AttributeError: + pass + + return new_server_defaults + + def get_declared_enums( metadata: Union[MetaData, List[MetaData]], schema: str, default_schema: str, connection: "Connection", + upgrade_ops: Union[UpgradeOps, None] = None, ) -> DeclaredEnumValues: """ Return a dict mapping SQLAlchemy declared enumeration types to the set of their values @@ -62,6 +100,8 @@ def get_declared_enums( Default schema name, likely will be "public" :param connection: Database connection + :param upgrade_ops: + Upgrade operations in current migration :returns DeclaredEnumValues: enum_values: { "my_enum": tuple(["a", "b", "c"]), @@ -75,6 +115,10 @@ def get_declared_enums( enum_name_to_values = dict() enum_name_to_table_references: defaultdict[str, Set[TableReference]] = defaultdict(set) + just_added_defaults = get_just_added_defaults(upgrade_ops, default_schema) + + # assert just_added_defaults == {}, just_added_defaults + if isinstance(metadata, list): metadata_list = metadata else: @@ -103,6 +147,8 @@ def get_declared_enums( table_schema = table.schema or default_schema column_default = get_column_default(connection, table_schema, table.name, column.name) + if (table_schema, table.name, column.name) in just_added_defaults: + column_default = just_added_defaults[table_schema, table.name, column.name] enum_name_to_table_references[column_type.name].add( # type: ignore[attr-defined] TableReference( table_schema=table_schema, diff --git a/alembic_postgresql_enum/sql_commands/column_default.py b/alembic_postgresql_enum/sql_commands/column_default.py index cef79ba..0d21175 100644 --- a/alembic_postgresql_enum/sql_commands/column_default.py +++ b/alembic_postgresql_enum/sql_commands/column_default.py @@ -1,3 +1,4 @@ +import re from typing import TYPE_CHECKING, Union, List, Tuple import sqlalchemy @@ -54,15 +55,38 @@ def rename_default_if_required( enum_name: str, enum_values_to_rename: List[Tuple[str, str]], ) -> str: - is_array = default_value.endswith("[]") + if schema: + new_enum = f"{schema}.{enum_name}" + else: + new_enum = enum_name + + if default_value.startswith("ARRAY["): + column_default_value = _replace_strings_in_quotes(default_value, enum_values_to_rename) + column_default_value = re.sub(r"::[.\w]+", f"::{new_enum}", column_default_value) + return column_default_value + + if default_value.endswith("[]"): + + # remove old type postfix + column_default_value = default_value[: default_value.find("::")] + + column_default_value = _replace_strings_in_quotes(column_default_value, enum_values_to_rename) + + return f"{column_default_value}::{new_enum}[]" + # remove old type postfix column_default_value = default_value[: default_value.find("::")] - for old_value, new_value in enum_values_to_rename: - column_default_value = column_default_value.replace(f"'{old_value}'", f"'{new_value}'") - column_default_value = column_default_value.replace(f'"{old_value}"', f'"{new_value}"') + column_default_value = _replace_strings_in_quotes(column_default_value, enum_values_to_rename) - suffix = "[]" if is_array else "" - if schema: - return f"{column_default_value}::{schema}.{enum_name}{suffix}" - return f"{column_default_value}::{enum_name}{suffix}" + return f"{column_default_value}::{new_enum}" + + +def _replace_strings_in_quotes( + old_default: str, + enum_values_to_rename: List[Tuple[str, str]], +) -> str: + for old_value, new_value in enum_values_to_rename: + old_default = old_default.replace(f"'{old_value}'", f"'{new_value}'") + old_default = old_default.replace(f'"{old_value}"', f'"{new_value}"') + return old_default diff --git a/tests/sync_enum_values/test_rename_default_if_required.py b/tests/sync_enum_values/test_rename_default_if_required.py index a76804e..752d7e1 100644 --- a/tests/sync_enum_values/test_rename_default_if_required.py +++ b/tests/sync_enum_values/test_rename_default_if_required.py @@ -31,3 +31,36 @@ def test_array_default_value_with_schema(): old_default_value = """'{}'::test.order_status_old[]""" assert rename_default_if_required("test", old_default_value, "order_status", []) == """'{}'::test.order_status[]""" + + +def test_caps_array_default_value_without_schema(): + old_default_value = """ARRAY['A'::my_old_enum, 'B'::my_old_enum]""" + + assert ( + rename_default_if_required("test", old_default_value, "my_enum", []) + == """ARRAY['A'::test.my_enum, 'B'::test.my_enum]""" + ) + + +def test_caps_array_default_value_with_schema(): + old_default_value = """ARRAY['A'::test.my_old_enum, 'B'::test.my_old_enum]""" + + assert ( + rename_default_if_required("test", old_default_value, "my_enum", []) + == """ARRAY['A'::test.my_enum, 'B'::test.my_enum]""" + ) + + +def test_caps_array_another_default_value_without_schema(): + old_default_value = """ARRAY['A'::my_old_enum, 'B'::my_old_enum]""" + + assert ( + rename_default_if_required("test", old_default_value, "my_enum", []) + == """ARRAY['A'::test.my_enum, 'B'::test.my_enum]""" + ) + + +def test_caps_array_another_default_value_with_schema(): + old_default_value = """ARRAY['A', 'B']::test.my_old_enum[]""" + + assert rename_default_if_required("test", old_default_value, "my_enum", []) == """ARRAY['A', 'B']::test.my_enum[]""" diff --git a/tests/sync_enum_values/test_run_array_new_column.py b/tests/sync_enum_values/test_run_array_new_column.py new file mode 100644 index 0000000..8114bc6 --- /dev/null +++ b/tests/sync_enum_values/test_run_array_new_column.py @@ -0,0 +1,91 @@ +from enum import Enum +from typing import TYPE_CHECKING + +import sqlalchemy +from sqlalchemy import MetaData, Table, Column, insert +from sqlalchemy.dialects import postgresql + +from tests.base.run_migration_test_abc import CompareAndRunTestCase + +if TYPE_CHECKING: + from sqlalchemy import Connection + + +class OldEnum(Enum): + A = "a" + B = "b" + + +class NewEnum(Enum): + A = "a" + B = "b" + C = "c" + + +class TestNewArrayColumnColumn(CompareAndRunTestCase): + def get_database_schema(self) -> MetaData: + database_schema = MetaData() + Table("a", database_schema) # , Column("value", postgresql.ARRAY(postgresql.ENUM(OldEnum))) + Table( + "b", + database_schema, + Column( + "value", + postgresql.ARRAY(postgresql.ENUM(OldEnum, name="my_enum")), + server_default=sqlalchemy.text("ARRAY['A', 'B']::my_enum[]"), + ), + ) + return database_schema + + def get_target_schema(self) -> MetaData: + target_schema = MetaData() + Table( + "a", + target_schema, + Column( + "value", + postgresql.ARRAY(postgresql.ENUM(NewEnum, name="my_enum")), + server_default=sqlalchemy.text("ARRAY['A', 'B']::my_enum[]"), + ), + ) + Table( + "b", + target_schema, + Column( + "value", + postgresql.ARRAY(postgresql.ENUM(NewEnum, name="my_enum")), + server_default=sqlalchemy.text("ARRAY['A', 'B']::my_enum[]"), + ), + ) + return target_schema + + def insert_migration_data(self, connection: "Connection", database_schema: MetaData) -> None: + a_table = database_schema.tables["a"] + connection.execute( + insert(a_table).values( + [ + {}, + {}, + ] + ) + ) + + def get_expected_upgrade(self) -> str: + return """ + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('a', sa.Column('value', postgresql.ARRAY(postgresql.ENUM('A', 'B', 'C', name='my_enum', create_type=False)), server_default=sa.text("ARRAY['A', 'B']::my_enum[]"), nullable=True)) + op.sync_enum_values('public', 'my_enum', ['A', 'B', 'C'], + [TableReference(table_schema='public', table_name='b', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A'::my_enum, 'B'::my_enum]"), TableReference(table_schema='public', table_name='a', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A', 'B']::my_enum[]")], + enum_values_to_rename=[]) + # ### end Alembic commands ### + """ + + def get_expected_downgrade(self) -> str: + return """ + # ### commands auto generated by Alembic - please adjust! ### + op.sync_enum_values('public', 'my_enum', ['A', 'B'], + [TableReference(table_schema='public', table_name='b', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A'::my_enum, 'B'::my_enum]"), TableReference(table_schema='public', table_name='a', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A', 'B']::my_enum[]")], + enum_values_to_rename=[]) + op.drop_column('a', 'value') + # ### end Alembic commands ### + """