From 8bcce1093c3fe9b54d7e756ec63cdcf06fbded6e Mon Sep 17 00:00:00 2001 From: "artem.golovin" Date: Sat, 29 Jun 2024 12:47:43 +0400 Subject: [PATCH 01/10] Fix error when adding column that uses existing changing enum --- alembic_postgresql_enum/compare_dispatch.py | 4 +- .../get_enum_data/declared_enums.py | 52 ++++++++++- .../sql_commands/column_default.py | 40 ++++++-- .../test_rename_default_if_required.py | 33 +++++++ .../test_run_array_new_column.py | 91 +++++++++++++++++++ 5 files changed, 208 insertions(+), 12 deletions(-) create mode 100644 tests/sync_enum_values/test_run_array_new_column.py diff --git a/alembic_postgresql_enum/compare_dispatch.py b/alembic_postgresql_enum/compare_dispatch.py index 1f40172..53ccba4 100644 --- a/alembic_postgresql_enum/compare_dispatch.py +++ b/alembic_postgresql_enum/compare_dispatch.py @@ -61,7 +61,9 @@ def compare_enums( schema = default_schema definitions = get_defined_enums(autogen_context.connection, schema) - declarations = get_declared_enums(autogen_context.metadata, schema, default_schema, autogen_context.connection) + declarations = get_declared_enums( + autogen_context.metadata, schema, default_schema, autogen_context.connection, upgrade_ops + ) create_new_enums(definitions, declarations.enum_values, schema, upgrade_ops) diff --git a/alembic_postgresql_enum/get_enum_data/declared_enums.py b/alembic_postgresql_enum/get_enum_data/declared_enums.py index 75706b8..77e2d61 100644 --- a/alembic_postgresql_enum/get_enum_data/declared_enums.py +++ b/alembic_postgresql_enum/get_enum_data/declared_enums.py @@ -1,9 +1,9 @@ from collections import defaultdict -from enum import Enum -from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast +from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast, Dict import sqlalchemy -from sqlalchemy import MetaData +from alembic.operations.ops import UpgradeOps, ModifyTableOps, AddColumnOp, CreateTableOp +from sqlalchemy import MetaData, Column from sqlalchemy.dialects import postgresql from alembic_postgresql_enum.sql_commands.column_default import get_column_default @@ -45,11 +45,49 @@ def column_type_is_enum(column_type: Any) -> bool: return False +def get_just_added_defaults( + upgrade_ops: Union[UpgradeOps, None], default_schema: str +) -> Dict[Tuple[str, str, str], str]: + """Get all server defaults that will be added in current migration""" + if upgrade_ops is None: + return {} + + new_server_defaults = {} + + for operations_group in upgrade_ops.ops: + if isinstance(operations_group, ModifyTableOps): + for operation in operations_group.ops: + if isinstance(operation, AddColumnOp): + try: + if operation.column.server_default is None: + continue + new_server_defaults[ + operation.schema or default_schema, operation.table_name, operation.column.name + ] = operation.column.server_default.arg.text + except AttributeError: + pass + + elif isinstance(operations_group, CreateTableOp): + for column in operations_group.columns: + if isinstance(column, Column): + try: + if column.server_default is None: + continue + new_server_defaults[column.table.schema or default_schema, column.table.name, column.name] = ( + column.server_default.arg.text + ) + except AttributeError: + pass + + return new_server_defaults + + def get_declared_enums( metadata: Union[MetaData, List[MetaData]], schema: str, default_schema: str, connection: "Connection", + upgrade_ops: Union[UpgradeOps, None] = None, ) -> DeclaredEnumValues: """ Return a dict mapping SQLAlchemy declared enumeration types to the set of their values @@ -62,6 +100,8 @@ def get_declared_enums( Default schema name, likely will be "public" :param connection: Database connection + :param upgrade_ops: + Upgrade operations in current migration :returns DeclaredEnumValues: enum_values: { "my_enum": tuple(["a", "b", "c"]), @@ -75,6 +115,10 @@ def get_declared_enums( enum_name_to_values = dict() enum_name_to_table_references: defaultdict[str, Set[TableReference]] = defaultdict(set) + just_added_defaults = get_just_added_defaults(upgrade_ops, default_schema) + + # assert just_added_defaults == {}, just_added_defaults + if isinstance(metadata, list): metadata_list = metadata else: @@ -103,6 +147,8 @@ def get_declared_enums( table_schema = table.schema or default_schema column_default = get_column_default(connection, table_schema, table.name, column.name) + if (table_schema, table.name, column.name) in just_added_defaults: + column_default = just_added_defaults[table_schema, table.name, column.name] enum_name_to_table_references[column_type.name].add( # type: ignore[attr-defined] TableReference( table_schema=table_schema, diff --git a/alembic_postgresql_enum/sql_commands/column_default.py b/alembic_postgresql_enum/sql_commands/column_default.py index cef79ba..0d21175 100644 --- a/alembic_postgresql_enum/sql_commands/column_default.py +++ b/alembic_postgresql_enum/sql_commands/column_default.py @@ -1,3 +1,4 @@ +import re from typing import TYPE_CHECKING, Union, List, Tuple import sqlalchemy @@ -54,15 +55,38 @@ def rename_default_if_required( enum_name: str, enum_values_to_rename: List[Tuple[str, str]], ) -> str: - is_array = default_value.endswith("[]") + if schema: + new_enum = f"{schema}.{enum_name}" + else: + new_enum = enum_name + + if default_value.startswith("ARRAY["): + column_default_value = _replace_strings_in_quotes(default_value, enum_values_to_rename) + column_default_value = re.sub(r"::[.\w]+", f"::{new_enum}", column_default_value) + return column_default_value + + if default_value.endswith("[]"): + + # remove old type postfix + column_default_value = default_value[: default_value.find("::")] + + column_default_value = _replace_strings_in_quotes(column_default_value, enum_values_to_rename) + + return f"{column_default_value}::{new_enum}[]" + # remove old type postfix column_default_value = default_value[: default_value.find("::")] - for old_value, new_value in enum_values_to_rename: - column_default_value = column_default_value.replace(f"'{old_value}'", f"'{new_value}'") - column_default_value = column_default_value.replace(f'"{old_value}"', f'"{new_value}"') + column_default_value = _replace_strings_in_quotes(column_default_value, enum_values_to_rename) - suffix = "[]" if is_array else "" - if schema: - return f"{column_default_value}::{schema}.{enum_name}{suffix}" - return f"{column_default_value}::{enum_name}{suffix}" + return f"{column_default_value}::{new_enum}" + + +def _replace_strings_in_quotes( + old_default: str, + enum_values_to_rename: List[Tuple[str, str]], +) -> str: + for old_value, new_value in enum_values_to_rename: + old_default = old_default.replace(f"'{old_value}'", f"'{new_value}'") + old_default = old_default.replace(f'"{old_value}"', f'"{new_value}"') + return old_default diff --git a/tests/sync_enum_values/test_rename_default_if_required.py b/tests/sync_enum_values/test_rename_default_if_required.py index a76804e..752d7e1 100644 --- a/tests/sync_enum_values/test_rename_default_if_required.py +++ b/tests/sync_enum_values/test_rename_default_if_required.py @@ -31,3 +31,36 @@ def test_array_default_value_with_schema(): old_default_value = """'{}'::test.order_status_old[]""" assert rename_default_if_required("test", old_default_value, "order_status", []) == """'{}'::test.order_status[]""" + + +def test_caps_array_default_value_without_schema(): + old_default_value = """ARRAY['A'::my_old_enum, 'B'::my_old_enum]""" + + assert ( + rename_default_if_required("test", old_default_value, "my_enum", []) + == """ARRAY['A'::test.my_enum, 'B'::test.my_enum]""" + ) + + +def test_caps_array_default_value_with_schema(): + old_default_value = """ARRAY['A'::test.my_old_enum, 'B'::test.my_old_enum]""" + + assert ( + rename_default_if_required("test", old_default_value, "my_enum", []) + == """ARRAY['A'::test.my_enum, 'B'::test.my_enum]""" + ) + + +def test_caps_array_another_default_value_without_schema(): + old_default_value = """ARRAY['A'::my_old_enum, 'B'::my_old_enum]""" + + assert ( + rename_default_if_required("test", old_default_value, "my_enum", []) + == """ARRAY['A'::test.my_enum, 'B'::test.my_enum]""" + ) + + +def test_caps_array_another_default_value_with_schema(): + old_default_value = """ARRAY['A', 'B']::test.my_old_enum[]""" + + assert rename_default_if_required("test", old_default_value, "my_enum", []) == """ARRAY['A', 'B']::test.my_enum[]""" diff --git a/tests/sync_enum_values/test_run_array_new_column.py b/tests/sync_enum_values/test_run_array_new_column.py new file mode 100644 index 0000000..8114bc6 --- /dev/null +++ b/tests/sync_enum_values/test_run_array_new_column.py @@ -0,0 +1,91 @@ +from enum import Enum +from typing import TYPE_CHECKING + +import sqlalchemy +from sqlalchemy import MetaData, Table, Column, insert +from sqlalchemy.dialects import postgresql + +from tests.base.run_migration_test_abc import CompareAndRunTestCase + +if TYPE_CHECKING: + from sqlalchemy import Connection + + +class OldEnum(Enum): + A = "a" + B = "b" + + +class NewEnum(Enum): + A = "a" + B = "b" + C = "c" + + +class TestNewArrayColumnColumn(CompareAndRunTestCase): + def get_database_schema(self) -> MetaData: + database_schema = MetaData() + Table("a", database_schema) # , Column("value", postgresql.ARRAY(postgresql.ENUM(OldEnum))) + Table( + "b", + database_schema, + Column( + "value", + postgresql.ARRAY(postgresql.ENUM(OldEnum, name="my_enum")), + server_default=sqlalchemy.text("ARRAY['A', 'B']::my_enum[]"), + ), + ) + return database_schema + + def get_target_schema(self) -> MetaData: + target_schema = MetaData() + Table( + "a", + target_schema, + Column( + "value", + postgresql.ARRAY(postgresql.ENUM(NewEnum, name="my_enum")), + server_default=sqlalchemy.text("ARRAY['A', 'B']::my_enum[]"), + ), + ) + Table( + "b", + target_schema, + Column( + "value", + postgresql.ARRAY(postgresql.ENUM(NewEnum, name="my_enum")), + server_default=sqlalchemy.text("ARRAY['A', 'B']::my_enum[]"), + ), + ) + return target_schema + + def insert_migration_data(self, connection: "Connection", database_schema: MetaData) -> None: + a_table = database_schema.tables["a"] + connection.execute( + insert(a_table).values( + [ + {}, + {}, + ] + ) + ) + + def get_expected_upgrade(self) -> str: + return """ + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('a', sa.Column('value', postgresql.ARRAY(postgresql.ENUM('A', 'B', 'C', name='my_enum', create_type=False)), server_default=sa.text("ARRAY['A', 'B']::my_enum[]"), nullable=True)) + op.sync_enum_values('public', 'my_enum', ['A', 'B', 'C'], + [TableReference(table_schema='public', table_name='b', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A'::my_enum, 'B'::my_enum]"), TableReference(table_schema='public', table_name='a', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A', 'B']::my_enum[]")], + enum_values_to_rename=[]) + # ### end Alembic commands ### + """ + + def get_expected_downgrade(self) -> str: + return """ + # ### commands auto generated by Alembic - please adjust! ### + op.sync_enum_values('public', 'my_enum', ['A', 'B'], + [TableReference(table_schema='public', table_name='b', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A'::my_enum, 'B'::my_enum]"), TableReference(table_schema='public', table_name='a', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A', 'B']::my_enum[]")], + enum_values_to_rename=[]) + op.drop_column('a', 'value') + # ### end Alembic commands ### + """ From 1c9e4f89b1521a75b840698a79680ee0bb70bdec Mon Sep 17 00:00:00 2001 From: "artem.golovin" Date: Sat, 29 Jun 2024 12:52:15 +0400 Subject: [PATCH 02/10] Remove insert_migration_data --- tests/sync_enum_values/test_run_array_new_column.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/sync_enum_values/test_run_array_new_column.py b/tests/sync_enum_values/test_run_array_new_column.py index 8114bc6..da2f671 100644 --- a/tests/sync_enum_values/test_run_array_new_column.py +++ b/tests/sync_enum_values/test_run_array_new_column.py @@ -59,17 +59,6 @@ def get_target_schema(self) -> MetaData: ) return target_schema - def insert_migration_data(self, connection: "Connection", database_schema: MetaData) -> None: - a_table = database_schema.tables["a"] - connection.execute( - insert(a_table).values( - [ - {}, - {}, - ] - ) - ) - def get_expected_upgrade(self) -> str: return """ # ### commands auto generated by Alembic - please adjust! ### From 8a7cf996fcd4eb1c424fcfe4b6dbfa9c47f6701b Mon Sep 17 00:00:00 2001 From: "artem.golovin" Date: Sat, 29 Jun 2024 13:00:49 +0400 Subject: [PATCH 03/10] Establish sync_enum_values TableReference order --- .../detection_of_changes/enum_alteration.py | 5 ++++- tests/sync_enum_values/test_run_array_new_column.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/alembic_postgresql_enum/detection_of_changes/enum_alteration.py b/alembic_postgresql_enum/detection_of_changes/enum_alteration.py index bc5c7ca..52cca2a 100644 --- a/alembic_postgresql_enum/detection_of_changes/enum_alteration.py +++ b/alembic_postgresql_enum/detection_of_changes/enum_alteration.py @@ -47,6 +47,9 @@ def sync_changed_enums( enum_name, list(old_values), list(new_values), - list(affected_columns), + sorted( + affected_columns, + key=lambda reference: (reference.table_schema, reference.table_name, reference.column_name), + ), ) upgrade_ops.ops.append(op) diff --git a/tests/sync_enum_values/test_run_array_new_column.py b/tests/sync_enum_values/test_run_array_new_column.py index da2f671..26d5451 100644 --- a/tests/sync_enum_values/test_run_array_new_column.py +++ b/tests/sync_enum_values/test_run_array_new_column.py @@ -64,7 +64,7 @@ def get_expected_upgrade(self) -> str: # ### commands auto generated by Alembic - please adjust! ### op.add_column('a', sa.Column('value', postgresql.ARRAY(postgresql.ENUM('A', 'B', 'C', name='my_enum', create_type=False)), server_default=sa.text("ARRAY['A', 'B']::my_enum[]"), nullable=True)) op.sync_enum_values('public', 'my_enum', ['A', 'B', 'C'], - [TableReference(table_schema='public', table_name='b', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A'::my_enum, 'B'::my_enum]"), TableReference(table_schema='public', table_name='a', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A', 'B']::my_enum[]")], + [TableReference(table_schema='public', table_name='a', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A', 'B']::my_enum[]"), TableReference(table_schema='public', table_name='b', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A'::my_enum, 'B'::my_enum]")], enum_values_to_rename=[]) # ### end Alembic commands ### """ @@ -73,7 +73,7 @@ def get_expected_downgrade(self) -> str: return """ # ### commands auto generated by Alembic - please adjust! ### op.sync_enum_values('public', 'my_enum', ['A', 'B'], - [TableReference(table_schema='public', table_name='b', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A'::my_enum, 'B'::my_enum]"), TableReference(table_schema='public', table_name='a', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A', 'B']::my_enum[]")], + [TableReference(table_schema='public', table_name='a', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A', 'B']::my_enum[]"), TableReference(table_schema='public', table_name='b', column_name='value', column_type=ColumnType.ARRAY, existing_server_default="ARRAY['A'::my_enum, 'B'::my_enum]")], enum_values_to_rename=[]) op.drop_column('a', 'value') # ### end Alembic commands ### From 39ec130d0ccada42e7347f318a8ad7af8b05e799 Mon Sep 17 00:00:00 2001 From: "artem.golovin" Date: Sat, 29 Jun 2024 13:04:13 +0400 Subject: [PATCH 04/10] Clean up assert --- alembic_postgresql_enum/get_enum_data/declared_enums.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/alembic_postgresql_enum/get_enum_data/declared_enums.py b/alembic_postgresql_enum/get_enum_data/declared_enums.py index 77e2d61..7e4062f 100644 --- a/alembic_postgresql_enum/get_enum_data/declared_enums.py +++ b/alembic_postgresql_enum/get_enum_data/declared_enums.py @@ -117,8 +117,6 @@ def get_declared_enums( just_added_defaults = get_just_added_defaults(upgrade_ops, default_schema) - # assert just_added_defaults == {}, just_added_defaults - if isinstance(metadata, list): metadata_list = metadata else: From f91dfd6f917f5521de97e6d8c5bfa92cea31b9b1 Mon Sep 17 00:00:00 2001 From: "artem.golovin" Date: Sat, 29 Jun 2024 13:05:00 +0400 Subject: [PATCH 05/10] Remove pull request action as it duplicates push --- .github/workflows/test_on_push.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/test_on_push.yaml b/.github/workflows/test_on_push.yaml index ac2be5a..f0299e9 100644 --- a/.github/workflows/test_on_push.yaml +++ b/.github/workflows/test_on_push.yaml @@ -10,7 +10,6 @@ on: - tests/** - alembic_postgresql_enum/** - .github/workflows/test_on_push.yaml - pull_request: { } jobs: run_tests: From da6cfe36a3649bb534dce2dc18631f16ab87e5e4 Mon Sep 17 00:00:00 2001 From: "artem.golovin" Date: Sat, 29 Jun 2024 13:23:49 +0400 Subject: [PATCH 06/10] Draft: Run tests locally in docker --- Dockerfile | 10 ++++++++++ docker-compose.yml | 28 ++++++++++++++++++++++++++++ tests/README.md | 8 ++++++++ 3 files changed, 46 insertions(+) create mode 100644 Dockerfile create mode 100644 docker-compose.yml diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..68360c3 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,10 @@ +FROM python:latest + +COPY ./alembic_postgresql_enum ./alembic_postgresql_enum +COPY ./tests ./tests + +WORKDIR ./tests + +RUN pip install -r requirements.txt + +ENTRYPOINT pytest diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..ca67328 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,28 @@ +version: "3.8" + +services: + run-tests: +# entrypoint: pytest + build: . + stdin_open: true + tty: true + command: + - pytest + environment: + DATABASE_URI: postgresql://test_user:test_password@db:5432/test_db + depends_on: + - db + links: + - "db:database" + db: + image: postgres:12 + environment: + POSTGRES_DB: "test_db" + POSTGRES_USER: "test_user" + POSTGRES_PASSWORD: "test_password" + PGUSER: "postgres" + + ports: + - "5432:5432" + volumes: + - ./api/db/postgres-test-data:/var/lib/postgresql/data diff --git a/tests/README.md b/tests/README.md index e08ea1d..0f64df7 100644 --- a/tests/README.md +++ b/tests/README.md @@ -24,4 +24,12 @@ pip install -R tests/requirements.txt Run tests ``` pytest +``` + +# In progress + +To run tests just use: +```commandline +docker compose build +docker compose up ``` \ No newline at end of file From 78d35c217fca0c584644b48364b8b627dfbdca17 Mon Sep 17 00:00:00 2001 From: "artem.golovin" Date: Sat, 29 Jun 2024 21:12:56 +0400 Subject: [PATCH 07/10] Add command to run tests with docker compose --- tests/README.md | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/README.md b/tests/README.md index 0f64df7..3740695 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,6 +1,17 @@ # How to run tests -Create database for testing +# With `docker compose` + +Just run: +```commandline +docker compose up --build --exit-code-from run-tests +``` + +# Manually + +## Create database + +Start postgres through docker compose: ## Env variables @@ -25,11 +36,3 @@ Run tests ``` pytest ``` - -# In progress - -To run tests just use: -```commandline -docker compose build -docker compose up -``` \ No newline at end of file From 757d16b180890cef29b0a49df12fb7597b48ab2b Mon Sep 17 00:00:00 2001 From: "artem.golovin" Date: Sun, 30 Jun 2024 13:45:45 +0400 Subject: [PATCH 08/10] Check whether server_default is changed with alter_column --- alembic_postgresql_enum/get_enum_data/declared_enums.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/alembic_postgresql_enum/get_enum_data/declared_enums.py b/alembic_postgresql_enum/get_enum_data/declared_enums.py index 7e4062f..2e0cfa9 100644 --- a/alembic_postgresql_enum/get_enum_data/declared_enums.py +++ b/alembic_postgresql_enum/get_enum_data/declared_enums.py @@ -2,7 +2,7 @@ from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast, Dict import sqlalchemy -from alembic.operations.ops import UpgradeOps, ModifyTableOps, AddColumnOp, CreateTableOp +from alembic.operations.ops import UpgradeOps, ModifyTableOps, AddColumnOp, CreateTableOp, AlterColumnOp from sqlalchemy import MetaData, Column from sqlalchemy.dialects import postgresql @@ -66,6 +66,11 @@ def get_just_added_defaults( ] = operation.column.server_default.arg.text except AttributeError: pass + elif isinstance(operation, AlterColumnOp): + if operation.modify_server_default is not False: + new_server_defaults[ + operation.schema or default_schema, operation.table_name, operation.column_name + ] = operation.modify_server_default elif isinstance(operations_group, CreateTableOp): for column in operations_group.columns: From 75789b070dc5352ce39e5469dfc59846285ba10b Mon Sep 17 00:00:00 2001 From: "artem.golovin" Date: Sun, 30 Jun 2024 23:08:25 +0400 Subject: [PATCH 09/10] Review fixes --- alembic_postgresql_enum/compare_dispatch.py | 13 ++-- .../detection_of_changes/enum_alteration.py | 2 +- .../get_enum_data/declared_enums.py | 51 ++----------- .../get_default_from_alembic_ops.py | 73 +++++++++++++++++++ 4 files changed, 86 insertions(+), 53 deletions(-) create mode 100644 alembic_postgresql_enum/get_enum_data/get_default_from_alembic_ops.py diff --git a/alembic_postgresql_enum/compare_dispatch.py b/alembic_postgresql_enum/compare_dispatch.py index 53ccba4..ad5d384 100644 --- a/alembic_postgresql_enum/compare_dispatch.py +++ b/alembic_postgresql_enum/compare_dispatch.py @@ -32,6 +32,13 @@ def compare_enums( for each defined enum that has changed new entries when compared to its declared version. """ + assert ( + autogen_context.dialect is not None + and autogen_context.dialect.default_schema_name is not None + and autogen_context.connection is not None + and autogen_context.metadata is not None + ) + if autogen_context.dialect.name != "postgresql": log.warning( f"This library only supports postgresql, but you are using {autogen_context.dialect.name}, skipping" @@ -49,12 +56,6 @@ def compare_enums( if isinstance(operations_group, CreateTableOp) and operations_group.schema not in schema_names: schema_names.append(operations_group.schema) - assert ( - autogen_context.dialect is not None - and autogen_context.dialect.default_schema_name is not None - and autogen_context.connection is not None - and autogen_context.metadata is not None - ) for schema in schema_names: default_schema = autogen_context.dialect.default_schema_name if schema is None: diff --git a/alembic_postgresql_enum/detection_of_changes/enum_alteration.py b/alembic_postgresql_enum/detection_of_changes/enum_alteration.py index 52cca2a..1d92c6b 100644 --- a/alembic_postgresql_enum/detection_of_changes/enum_alteration.py +++ b/alembic_postgresql_enum/detection_of_changes/enum_alteration.py @@ -47,7 +47,7 @@ def sync_changed_enums( enum_name, list(old_values), list(new_values), - sorted( + sorted( # Sort references alphabetically for consistency of generated text affected_columns, key=lambda reference: (reference.table_schema, reference.table_name, reference.column_name), ), diff --git a/alembic_postgresql_enum/get_enum_data/declared_enums.py b/alembic_postgresql_enum/get_enum_data/declared_enums.py index 2e0cfa9..98e0200 100644 --- a/alembic_postgresql_enum/get_enum_data/declared_enums.py +++ b/alembic_postgresql_enum/get_enum_data/declared_enums.py @@ -1,11 +1,12 @@ from collections import defaultdict -from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast, Dict +from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast, Optional import sqlalchemy -from alembic.operations.ops import UpgradeOps, ModifyTableOps, AddColumnOp, CreateTableOp, AlterColumnOp -from sqlalchemy import MetaData, Column +from alembic.operations.ops import UpgradeOps +from sqlalchemy import MetaData from sqlalchemy.dialects import postgresql +from alembic_postgresql_enum.get_enum_data.get_default_from_alembic_ops import get_just_added_defaults from alembic_postgresql_enum.sql_commands.column_default import get_column_default if TYPE_CHECKING: @@ -45,54 +46,12 @@ def column_type_is_enum(column_type: Any) -> bool: return False -def get_just_added_defaults( - upgrade_ops: Union[UpgradeOps, None], default_schema: str -) -> Dict[Tuple[str, str, str], str]: - """Get all server defaults that will be added in current migration""" - if upgrade_ops is None: - return {} - - new_server_defaults = {} - - for operations_group in upgrade_ops.ops: - if isinstance(operations_group, ModifyTableOps): - for operation in operations_group.ops: - if isinstance(operation, AddColumnOp): - try: - if operation.column.server_default is None: - continue - new_server_defaults[ - operation.schema or default_schema, operation.table_name, operation.column.name - ] = operation.column.server_default.arg.text - except AttributeError: - pass - elif isinstance(operation, AlterColumnOp): - if operation.modify_server_default is not False: - new_server_defaults[ - operation.schema or default_schema, operation.table_name, operation.column_name - ] = operation.modify_server_default - - elif isinstance(operations_group, CreateTableOp): - for column in operations_group.columns: - if isinstance(column, Column): - try: - if column.server_default is None: - continue - new_server_defaults[column.table.schema or default_schema, column.table.name, column.name] = ( - column.server_default.arg.text - ) - except AttributeError: - pass - - return new_server_defaults - - def get_declared_enums( metadata: Union[MetaData, List[MetaData]], schema: str, default_schema: str, connection: "Connection", - upgrade_ops: Union[UpgradeOps, None] = None, + upgrade_ops: Optional[UpgradeOps] = None, ) -> DeclaredEnumValues: """ Return a dict mapping SQLAlchemy declared enumeration types to the set of their values diff --git a/alembic_postgresql_enum/get_enum_data/get_default_from_alembic_ops.py b/alembic_postgresql_enum/get_enum_data/get_default_from_alembic_ops.py new file mode 100644 index 0000000..785ecf3 --- /dev/null +++ b/alembic_postgresql_enum/get_enum_data/get_default_from_alembic_ops.py @@ -0,0 +1,73 @@ +from typing import Optional, Dict, Tuple + +from alembic.operations.ops import UpgradeOps, ModifyTableOps, AddColumnOp, AlterColumnOp, CreateTableOp +from sqlalchemy import Column + +SchemaName = str +TableName = str +ColumnName = str +ColumnLocation = Tuple[SchemaName, TableName, ColumnName] + + +def _get_default_from_add_column_op(op: AddColumnOp, default_schema: str) -> Tuple[ColumnLocation, Optional[str]]: + if op.column.server_default is None: + raise AttributeError("No new server_default") + return ( + (op.schema or default_schema, op.table_name, op.column.name), + op.column.server_default.arg.text, # type: ignore[attr-defined] + ) + + +def _get_default_from_alter_column_op(op: AlterColumnOp, default_schema: str) -> Tuple[ColumnLocation, Optional[str]]: + if op.modify_server_default is False: + raise AttributeError("No new server_default") + return (op.schema or default_schema, op.table_name, op.column_name), op.modify_server_default + + +def _get_default_from_column(column: Column, default_schema: str) -> Tuple[ColumnLocation, Optional[str]]: + if column.server_default is None: + raise AttributeError("No new server_default") + return ( + (column.table.schema or default_schema, column.table.name, column.name), + column.server_default.arg.text, # type: ignore[attr-defined] + ) + + +def get_just_added_defaults( + upgrade_ops: Optional[UpgradeOps], default_schema: str +) -> Dict[ColumnLocation, Optional[str]]: + """Get all server defaults that will be added in current migration""" + if upgrade_ops is None: + return {} + + new_server_defaults = {} + + for operations_group in upgrade_ops.ops: + if isinstance(operations_group, ModifyTableOps): + for operation in operations_group.ops: + if isinstance(operation, AddColumnOp): + try: + column_location, column_new_default = _get_default_from_add_column_op(operation, default_schema) + new_server_defaults[column_location] = column_new_default + except AttributeError: + pass + + elif isinstance(operation, AlterColumnOp): + try: + column_location, column_new_default = _get_default_from_alter_column_op( + operation, default_schema + ) + new_server_defaults[column_location] = column_new_default + except AttributeError: + pass + + elif isinstance(operations_group, CreateTableOp): + for column in operations_group.columns: + if isinstance(column, Column): + try: + column_location, column_new_default = _get_default_from_column(column, default_schema) + new_server_defaults[column_location] = column_new_default + except AttributeError: + pass + + return new_server_defaults From 0d046705c7f7b0d2a80bf4ca0d415d070aa33b24 Mon Sep 17 00:00:00 2001 From: Artem Golovin Date: Sat, 13 Jul 2024 12:09:35 +0400 Subject: [PATCH 10/10] Bump version to 1.3.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3907f4d..258e1f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "alembic-postgresql-enum" -version = "1.2.0" +version = "1.3.0" description = "Alembic autogenerate support for creation, alteration and deletion of enums" authors = ["RustyGuard"] license = "MIT"