From f8b18febe5885d87e4a912d1d07c92487d6703e6 Mon Sep 17 00:00:00 2001 From: rusty Date: Fri, 16 Feb 2024 21:20:56 +0400 Subject: [PATCH 1/4] Test for #63 --- tests/base/render_and_run.py | 4 +- tests/base/run_migration_test_abc.py | 39 +++++++++++++++++ tests/sync_enum_values/test_render.py | 60 +++++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 tests/base/run_migration_test_abc.py diff --git a/tests/base/render_and_run.py b/tests/base/render_and_run.py index 7062847..ff39f60 100644 --- a/tests/base/render_and_run.py +++ b/tests/base/render_and_run.py @@ -7,7 +7,7 @@ from sqlalchemy import MetaData from sqlalchemy.dialects import postgresql -from alembic_postgresql_enum import ColumnType +from alembic_postgresql_enum import ColumnType, TableReference from tests.utils.migration_context import create_migration_context if TYPE_CHECKING: @@ -46,6 +46,7 @@ def compare_and_run( "sa": sqlalchemy, "postgresql": postgresql, "ColumnType": ColumnType, + "TableReference": TableReference, }, ) exec( @@ -55,5 +56,6 @@ def compare_and_run( "sa": sqlalchemy, "postgresql": postgresql, "ColumnType": ColumnType, + "TableReference": TableReference, }, ) diff --git a/tests/base/run_migration_test_abc.py b/tests/base/run_migration_test_abc.py new file mode 100644 index 0000000..b31a8b1 --- /dev/null +++ b/tests/base/run_migration_test_abc.py @@ -0,0 +1,39 @@ +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from tests.base.render_and_run import compare_and_run + +if TYPE_CHECKING: + from sqlalchemy import Connection +from sqlalchemy import MetaData + + +class CompareAndRunTestCase(ABC): + @abstractmethod + def get_database_schema(self) -> MetaData: + ... + + @abstractmethod + def get_target_schema(self) -> MetaData: + ... + + @abstractmethod + def get_expected_upgrade(self) -> str: + ... + + @abstractmethod + def get_expected_downgrade(self) -> str: + ... + + def test_run(self, connection: "Connection"): + database_schema = self.get_database_schema() + target_schema = self.get_target_schema() + + database_schema.create_all(connection) + + compare_and_run( + connection, + target_schema, + expected_upgrade=self.get_expected_upgrade(), + expected_downgrade=self.get_expected_downgrade(), + ) diff --git a/tests/sync_enum_values/test_render.py b/tests/sync_enum_values/test_render.py index 4a1ed99..f100cae 100644 --- a/tests/sync_enum_values/test_render.py +++ b/tests/sync_enum_values/test_render.py @@ -1,14 +1,19 @@ +import enum from typing import TYPE_CHECKING from alembic import autogenerate from alembic.autogenerate import api from alembic.operations import ops +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import declarative_base from alembic_postgresql_enum.get_enum_data import TableReference from alembic_postgresql_enum.operations import SyncEnumValuesOp +from tests.base.run_migration_test_abc import CompareAndRunTestCase if TYPE_CHECKING: from sqlalchemy import Connection +from sqlalchemy import MetaData, Column, Integer from tests.schemas import ( get_schema_with_enum_variants, @@ -161,3 +166,58 @@ def test_rename_enum_value_diff_tuple(connection: "Connection"): assert affected_columns == [ TableReference(table_schema=DEFAULT_SCHEMA, table_name=USER_TABLE_NAME, column_name=USER_STATUS_COLUMN_NAME) ] + + +class TestServerDefault(CompareAndRunTestCase): + def get_database_schema(self) -> MetaData: + schema = MetaData() + + Base = declarative_base(metadata=schema) + + class MyEnum(enum.Enum): + one = 1 + two = 2 + three = 3 + + class ExampleTable(Base): + __tablename__ = "example_table" + test_field = Column(Integer, primary_key=True, autoincrement=False) + enum_field = Column(postgresql.ENUM(MyEnum, name="my_enum"), server_default=MyEnum.one.name) + + return schema + + def get_target_schema(self) -> MetaData: + schema = MetaData() + + Base = declarative_base(metadata=schema) + + class NewMyEnum(enum.Enum): + one = 1 + two = 2 + three = 3 + four = 4 # added + + class ExampleTable(Base): + __tablename__ = "example_table" + test_field = Column(Integer, primary_key=True, autoincrement=False) + enum_field = Column(postgresql.ENUM(NewMyEnum, name="my_enum"), server_default=NewMyEnum.one.name) + + return schema + + def get_expected_upgrade(self) -> str: + return """ + # ### commands auto generated by Alembic - please adjust! ### + op.sync_enum_values('public', 'my_enum', ['one', 'two', 'three', 'four'], + [TableReference(table_schema='public', table_name='example_table', column_name='enum_field', existing_server_default="'one'::my_enum")], + enum_values_to_rename=[]) + # ### end Alembic commands ### + """ + + def get_expected_downgrade(self) -> str: + return """ + # ### commands auto generated by Alembic - please adjust! ### + op.sync_enum_values('public', 'my_enum', ['one', 'two', 'three'], + [TableReference(table_schema='public', table_name='example_table', column_name='enum_field', existing_server_default="'one'::my_enum")], + enum_values_to_rename=[]) + # ### end Alembic commands ### + """ From ff2b0628f072fc4b4c2d70469892be64e59bf001 Mon Sep 17 00:00:00 2001 From: rusty Date: Fri, 16 Feb 2024 21:21:01 +0400 Subject: [PATCH 2/4] Fix for #63 --- alembic_postgresql_enum/get_enum_data/declared_enums.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alembic_postgresql_enum/get_enum_data/declared_enums.py b/alembic_postgresql_enum/get_enum_data/declared_enums.py index e21a516..611da42 100644 --- a/alembic_postgresql_enum/get_enum_data/declared_enums.py +++ b/alembic_postgresql_enum/get_enum_data/declared_enums.py @@ -101,7 +101,7 @@ def get_declared_enums( enum_name_to_values[column_type.name] = get_enum_values(column_type) table_schema = table.schema or default_schema - column_default = get_column_default(connection, table.schema, table.name, column.name) + column_default = get_column_default(connection, table_schema, table.name, column.name) enum_name_to_table_references[column_type.name].add( TableReference( table_schema=table_schema, From c6a05a9fb6f1234b790000153646bf8d85492bd3 Mon Sep 17 00:00:00 2001 From: rusty Date: Fri, 16 Feb 2024 21:22:44 +0400 Subject: [PATCH 3/4] Fix lint --- tests/base/run_migration_test_abc.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/base/run_migration_test_abc.py b/tests/base/run_migration_test_abc.py index b31a8b1..fb2b7cb 100644 --- a/tests/base/run_migration_test_abc.py +++ b/tests/base/run_migration_test_abc.py @@ -10,20 +10,16 @@ class CompareAndRunTestCase(ABC): @abstractmethod - def get_database_schema(self) -> MetaData: - ... + def get_database_schema(self) -> MetaData: ... @abstractmethod - def get_target_schema(self) -> MetaData: - ... + def get_target_schema(self) -> MetaData: ... @abstractmethod - def get_expected_upgrade(self) -> str: - ... + def get_expected_upgrade(self) -> str: ... @abstractmethod - def get_expected_downgrade(self) -> str: - ... + def get_expected_downgrade(self) -> str: ... def test_run(self, connection: "Connection"): database_schema = self.get_database_schema() From 69e75e31480ce0857df6e55bd2e0fc6f19f73bd3 Mon Sep 17 00:00:00 2001 From: Artem Golovin Date: Fri, 16 Feb 2024 21:27:28 +0400 Subject: [PATCH 4/4] Bump version to 1.1.2 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 28bface..60fec30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "alembic-postgresql-enum" -version = "1.1.1" +version = "1.1.2" description = "Alembic autogenerate support for creation, alteration and deletion of enums" authors = ["RustyGuard"] license = "MIT"