diff --git a/.github/workflows/test_on_push.yaml b/.github/workflows/test_on_push.yaml index ac2be5a..f0299e9 100644 --- a/.github/workflows/test_on_push.yaml +++ b/.github/workflows/test_on_push.yaml @@ -10,7 +10,6 @@ on: - tests/** - alembic_postgresql_enum/** - .github/workflows/test_on_push.yaml - pull_request: { } jobs: run_tests: diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..68360c3 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,10 @@ +FROM python:latest + +COPY ./alembic_postgresql_enum ./alembic_postgresql_enum +COPY ./tests ./tests + +WORKDIR ./tests + +RUN pip install -r requirements.txt + +ENTRYPOINT pytest diff --git a/alembic_postgresql_enum/compare_dispatch.py b/alembic_postgresql_enum/compare_dispatch.py index 1f40172..ad5d384 100644 --- a/alembic_postgresql_enum/compare_dispatch.py +++ b/alembic_postgresql_enum/compare_dispatch.py @@ -32,6 +32,13 @@ def compare_enums( for each defined enum that has changed new entries when compared to its declared version. """ + assert ( + autogen_context.dialect is not None + and autogen_context.dialect.default_schema_name is not None + and autogen_context.connection is not None + and autogen_context.metadata is not None + ) + if autogen_context.dialect.name != "postgresql": log.warning( f"This library only supports postgresql, but you are using {autogen_context.dialect.name}, skipping" @@ -49,19 +56,15 @@ def compare_enums( if isinstance(operations_group, CreateTableOp) and operations_group.schema not in schema_names: schema_names.append(operations_group.schema) - assert ( - autogen_context.dialect is not None - and autogen_context.dialect.default_schema_name is not None - and autogen_context.connection is not None - and autogen_context.metadata is not None - ) for schema in schema_names: default_schema = autogen_context.dialect.default_schema_name if schema is None: schema = default_schema definitions = get_defined_enums(autogen_context.connection, schema) - declarations = get_declared_enums(autogen_context.metadata, schema, default_schema, autogen_context.connection) + declarations = get_declared_enums( + autogen_context.metadata, schema, default_schema, autogen_context.connection, upgrade_ops + ) create_new_enums(definitions, declarations.enum_values, schema, upgrade_ops) diff --git a/alembic_postgresql_enum/detection_of_changes/enum_alteration.py b/alembic_postgresql_enum/detection_of_changes/enum_alteration.py index bc5c7ca..1d92c6b 100644 --- a/alembic_postgresql_enum/detection_of_changes/enum_alteration.py +++ b/alembic_postgresql_enum/detection_of_changes/enum_alteration.py @@ -47,6 +47,9 @@ def sync_changed_enums( enum_name, list(old_values), list(new_values), - list(affected_columns), + sorted( # Sort references alphabetically for consistency of generated text + affected_columns, + key=lambda reference: (reference.table_schema, reference.table_name, reference.column_name), + ), ) upgrade_ops.ops.append(op) diff --git a/alembic_postgresql_enum/get_enum_data/declared_enums.py b/alembic_postgresql_enum/get_enum_data/declared_enums.py index 75706b8..98e0200 100644 --- a/alembic_postgresql_enum/get_enum_data/declared_enums.py +++ b/alembic_postgresql_enum/get_enum_data/declared_enums.py @@ -1,11 +1,12 @@ from collections import defaultdict -from enum import Enum -from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast +from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast, Optional import sqlalchemy +from alembic.operations.ops import UpgradeOps from sqlalchemy import MetaData from sqlalchemy.dialects import postgresql +from alembic_postgresql_enum.get_enum_data.get_default_from_alembic_ops import get_just_added_defaults from alembic_postgresql_enum.sql_commands.column_default import get_column_default if TYPE_CHECKING: @@ -50,6 +51,7 @@ def get_declared_enums( schema: str, default_schema: str, connection: "Connection", + upgrade_ops: Optional[UpgradeOps] = None, ) -> DeclaredEnumValues: """ Return a dict mapping SQLAlchemy declared enumeration types to the set of their values @@ -62,6 +64,8 @@ def get_declared_enums( Default schema name, likely will be "public" :param connection: Database connection + :param upgrade_ops: + Upgrade operations in current migration :returns DeclaredEnumValues: enum_values: { "my_enum": tuple(["a", "b", "c"]), @@ -75,6 +79,8 @@ def get_declared_enums( enum_name_to_values = dict() enum_name_to_table_references: defaultdict[str, Set[TableReference]] = defaultdict(set) + just_added_defaults = get_just_added_defaults(upgrade_ops, default_schema) + if isinstance(metadata, list): metadata_list = metadata else: @@ -103,6 +109,8 @@ def get_declared_enums( table_schema = table.schema or default_schema column_default = get_column_default(connection, table_schema, table.name, column.name) + if (table_schema, table.name, column.name) in just_added_defaults: + column_default = just_added_defaults[table_schema, table.name, column.name] enum_name_to_table_references[column_type.name].add( # type: ignore[attr-defined] TableReference( table_schema=table_schema, diff --git a/alembic_postgresql_enum/get_enum_data/get_default_from_alembic_ops.py b/alembic_postgresql_enum/get_enum_data/get_default_from_alembic_ops.py new file mode 100644 index 0000000..785ecf3 --- /dev/null +++ b/alembic_postgresql_enum/get_enum_data/get_default_from_alembic_ops.py @@ -0,0 +1,73 @@ +from typing import Optional, Dict, Tuple + +from alembic.operations.ops import UpgradeOps, ModifyTableOps, AddColumnOp, AlterColumnOp, CreateTableOp +from sqlalchemy import Column + +SchemaName = str +TableName = str +ColumnName = str +ColumnLocation = Tuple[SchemaName, TableName, ColumnName] + + +def _get_default_from_add_column_op(op: AddColumnOp, default_schema: str) -> Tuple[ColumnLocation, Optional[str]]: + if op.column.server_default is None: + raise AttributeError("No new server_default") + return ( + (op.schema or default_schema, op.table_name, op.column.name), + op.column.server_default.arg.text, # type: ignore[attr-defined] + ) + + +def _get_default_from_alter_column_op(op: AlterColumnOp, default_schema: str) -> Tuple[ColumnLocation, Optional[str]]: + if op.modify_server_default is False: + raise AttributeError("No new server_default") + return (op.schema or default_schema, op.table_name, op.column_name), op.modify_server_default + + +def _get_default_from_column(column: Column, default_schema: str) -> Tuple[ColumnLocation, Optional[str]]: + if column.server_default is None: + raise AttributeError("No new server_default") + return ( + (column.table.schema or default_schema, column.table.name, column.name), + column.server_default.arg.text, # type: ignore[attr-defined] + ) + + +def get_just_added_defaults( + upgrade_ops: Optional[UpgradeOps], default_schema: str +) -> Dict[ColumnLocation, Optional[str]]: + """Get all server defaults that will be added in current migration""" + if upgrade_ops is None: + return {} + + new_server_defaults = {} + + for operations_group in upgrade_ops.ops: + if isinstance(operations_group, ModifyTableOps): + for operation in operations_group.ops: + if isinstance(operation, AddColumnOp): + try: + column_location, column_new_default = _get_default_from_add_column_op(operation, default_schema) + new_server_defaults[column_location] = column_new_default + except AttributeError: + pass + + elif isinstance(operation, AlterColumnOp): + try: + column_location, column_new_default = _get_default_from_alter_column_op( + operation, default_schema + ) + new_server_defaults[column_location] = column_new_default + except AttributeError: + pass + + elif isinstance(operations_group, CreateTableOp): + for column in operations_group.columns: + if isinstance(column, Column): + try: + column_location, column_new_default = _get_default_from_column(column, default_schema) + new_server_defaults[column_location] = column_new_default + except AttributeError: + pass + + return new_server_defaults diff --git a/alembic_postgresql_enum/sql_commands/column_default.py b/alembic_postgresql_enum/sql_commands/column_default.py index cef79ba..0d21175 100644 --- a/alembic_postgresql_enum/sql_commands/column_default.py +++ b/alembic_postgresql_enum/sql_commands/column_default.py @@ -1,3 +1,4 @@ +import re from typing import TYPE_CHECKING, Union, List, Tuple import sqlalchemy @@ -54,15 +55,38 @@ def rename_default_if_required( enum_name: str, enum_values_to_rename: List[Tuple[str, str]], ) -> str: - is_array = default_value.endswith("[]") + if schema: + new_enum = f"{schema}.{enum_name}" + else: + new_enum = enum_name + + if default_value.startswith("ARRAY["): + column_default_value = _replace_strings_in_quotes(default_value, enum_values_to_rename) + column_default_value = re.sub(r"::[.\w]+", f"::{new_enum}", column_default_value) + return column_default_value + + if default_value.endswith("[]"): + + # remove old type postfix + column_default_value = default_value[: default_value.find("::")] + + column_default_value = _replace_strings_in_quotes(column_default_value, enum_values_to_rename) + + return f"{column_default_value}::{new_enum}[]" + # remove old type postfix column_default_value = default_value[: default_value.find("::")] - for old_value, new_value in enum_values_to_rename: - column_default_value = column_default_value.replace(f"'{old_value}'", f"'{new_value}'") - column_default_value = column_default_value.replace(f'"{old_value}"', f'"{new_value}"') + column_default_value = _replace_strings_in_quotes(column_default_value, enum_values_to_rename) - suffix = "[]" if is_array else "" - if schema: - return f"{column_default_value}::{schema}.{enum_name}{suffix}" - return f"{column_default_value}::{enum_name}{suffix}" + return f"{column_default_value}::{new_enum}" + + +def _replace_strings_in_quotes( + old_default: str, + enum_values_to_rename: List[Tuple[str, str]], +) -> str: + for old_value, new_value in enum_values_to_rename: + old_default = old_default.replace(f"'{old_value}'", f"'{new_value}'") + old_default = old_default.replace(f'"{old_value}"', f'"{new_value}"') + return old_default diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..ca67328 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,28 @@ +version: "3.8" + +services: + run-tests: +# entrypoint: pytest + build: . + stdin_open: true + tty: true + command: + - pytest + environment: + DATABASE_URI: postgresql://test_user:test_password@db:5432/test_db + depends_on: + - db + links: + - "db:database" + db: + image: postgres:12 + environment: + POSTGRES_DB: "test_db" + POSTGRES_USER: "test_user" + POSTGRES_PASSWORD: "test_password" + PGUSER: "postgres" + + ports: + - "5432:5432" + volumes: + - ./api/db/postgres-test-data:/var/lib/postgresql/data diff --git a/pyproject.toml b/pyproject.toml index 3907f4d..258e1f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "alembic-postgresql-enum" -version = "1.2.0" +version = "1.3.0" description = "Alembic autogenerate support for creation, alteration and deletion of enums" authors = ["RustyGuard"] license = "MIT" diff --git a/tests/README.md b/tests/README.md index e08ea1d..3740695 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,6 +1,17 @@ # How to run tests -Create database for testing +# With `docker compose` + +Just run: +```commandline +docker compose up --build --exit-code-from run-tests +``` + +# Manually + +## Create database + +Start postgres through docker compose: ## Env variables @@ -24,4 +35,4 @@ pip install -R tests/requirements.txt Run tests ``` pytest -``` \ No newline at end of file +``` diff --git a/tests/sync_enum_values/test_rename_default_if_required.py b/tests/sync_enum_values/test_rename_default_if_required.py index a76804e..752d7e1 100644 --- a/tests/sync_enum_values/test_rename_default_if_required.py +++ b/tests/sync_enum_values/test_rename_default_if_required.py @@ -31,3 +31,36 @@ def test_array_default_value_with_schema(): old_default_value = """'{}'::test.order_status_old[]""" assert rename_default_if_required("test", old_default_value, "order_status", []) == """'{}'::test.order_status[]""" + + +def test_caps_array_default_value_without_schema(): + old_default_value = """ARRAY['A'::my_old_enum, 'B'::my_old_enum]""" + + assert ( + rename_default_if_required("test", old_default_value, "my_enum", []) + == """ARRAY['A'::test.my_enum, 'B'::test.my_enum]""" + ) + + +def test_caps_array_default_value_with_schema(): + old_default_value = """ARRAY['A'::test.my_old_enum, 'B'::test.my_old_enum]""" + + assert ( + rename_default_if_required("test", old_default_value, "my_enum", []) + == """ARRAY['A'::test.my_enum, 'B'::test.my_enum]""" + ) + + +def test_caps_array_another_default_value_without_schema(): + old_default_value = """ARRAY['A'::my_old_enum, 'B'::my_old_enum]""" + + assert ( + rename_default_if_required("test", old_default_value, "my_enum", []) + == """ARRAY['A'::test.my_enum, 'B'::test.my_enum]""" + ) + + +def test_caps_array_another_default_value_with_schema(): + old_default_value = """ARRAY['A', 'B']::test.my_old_enum[]""" + + assert rename_default_if_required("test", old_default_value, "my_enum", []) == """ARRAY['A', 'B']::test.my_enum[]""" diff --git a/tests/sync_enum_values/test_run_array_new_column.py b/tests/sync_enum_values/test_run_array_new_column.py new file mode 100644 index 0000000..26d5451 --- /dev/null +++ b/tests/sync_enum_values/test_run_array_new_column.py @@ -0,0 +1,80 @@ +from enum import Enum +from typing import TYPE_CHECKING + +import sqlalchemy +from sqlalchemy import MetaData, Table, Column, insert +from sqlalchemy.dialects import postgresql + +from tests.base.run_migration_test_abc import CompareAndRunTestCase + +if TYPE_CHECKING: + from sqlalchemy import Connection + + +class OldEnum(Enum): + A = "a" + B = "b" + + +class NewEnum(Enum): + A = "a" + B = "b" + C = "c" + + +class TestNewArrayColumnColumn(CompareAndRunTestCase): + def get_database_schema(self) -> MetaData: + database_schema = MetaData() + Table("a", database_schema) # , Column("value", postgresql.ARRAY(postgresql.ENUM(OldEnum))) + Table( + "b", + database_schema, + Column( + "value", + postgresql.ARRAY(postgresql.ENUM(OldEnum, name="my_enum")), + server_default=sqlalchemy.text("ARRAY['A', 'B']::my_enum[]"), + ), + ) + return database_schema + + def get_target_schema(self) -> MetaData: + target_schema = MetaData() + Table( + "a", + target_schema, + Column( + "value", + postgresql.ARRAY(postgresql.ENUM(NewEnum, name="my_enum")), + server_default=sqlalchemy.text("ARRAY['A', 'B']::my_enum[]"), + ), + ) + Table( + "b", + target_schema, + Column( + "value", + postgresql.ARRAY(postgresql.ENUM(NewEnum, name="my_enum")), + server_default=sqlalchemy.text("ARRAY['A', 'B']::my_enum[]"), + ), + ) + return target_schema + + def get_expected_upgrade(self) -> str: + return """ + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('a', sa.Column('value', postgresql.ARRAY(postgresql.ENUM('A', 'B', 'C', name='my_enum', create_type=False)), server_default=sa.text("ARRAY['A', 'B']::my_enum[]"), nullable=True)) + op.sync_enum_values('public', 'my_enum', ['A', 'B', 'C'], + [TableReference(table_schema='public', table_name='a', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A', 'B']::my_enum[]"), TableReference(table_schema='public', table_name='b', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A'::my_enum, 'B'::my_enum]")], + enum_values_to_rename=[]) + # ### end Alembic commands ### + """ + + def get_expected_downgrade(self) -> str: + return """ + # ### commands auto generated by Alembic - please adjust! ### + op.sync_enum_values('public', 'my_enum', ['A', 'B'], + [TableReference(table_schema='public', table_name='a', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A', 'B']::my_enum[]"), TableReference(table_schema='public', table_name='b', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A'::my_enum, 'B'::my_enum]")], + enum_values_to_rename=[]) + op.drop_column('a', 'value') + # ### end Alembic commands ### + """