diff --git a/tests/base/render_and_run.py b/tests/base/render_and_run.py index a4e5fe6..548d8d2 100644 --- a/tests/base/render_and_run.py +++ b/tests/base/render_and_run.py @@ -1,5 +1,5 @@ import textwrap -from typing import TYPE_CHECKING, Union, List +from typing import TYPE_CHECKING, Union, List, Optional import sqlalchemy from alembic import autogenerate @@ -20,6 +20,7 @@ def compare_and_run( *, expected_upgrade: str, expected_downgrade: str, + expected_imports: Optional[str], disable_running: bool = False, ): """Compares generated migration script is equal to expected_upgrade and expected_downgrade, then runs it""" @@ -37,6 +38,8 @@ def compare_and_run( expected_upgrade = textwrap.dedent(expected_upgrade).strip("\n ") expected_downgrade = textwrap.dedent(expected_downgrade).strip("\n ") + if expected_imports is not None: + assert template_args["imports"] == expected_imports assert upgrade_code == expected_upgrade, f"Got:\n{upgrade_code!r}\nExpected:\n{expected_upgrade!r}" assert downgrade_code == expected_downgrade, f"Got:\n{downgrade_code!r}\nExpected:\n{expected_downgrade!r}" diff --git a/tests/base/run_migration_test_abc.py b/tests/base/run_migration_test_abc.py index 64dfb76..b9fd33a 100644 --- a/tests/base/run_migration_test_abc.py +++ b/tests/base/run_migration_test_abc.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import alembic_postgresql_enum from alembic_postgresql_enum.configuration import Config, get_configuration @@ -33,6 +33,9 @@ def get_expected_upgrade(self) -> str: ... @abstractmethod def get_expected_downgrade(self) -> str: ... + def get_expected_imports(self) -> Optional[str]: + return None + def test_run(self, connection: "Connection"): old_config = get_configuration() alembic_postgresql_enum.set_configuration(self.config) @@ -47,6 +50,7 @@ def test_run(self, connection: "Connection"): target_schema, expected_upgrade=self.get_expected_upgrade(), expected_downgrade=self.get_expected_downgrade(), + expected_imports=self.get_expected_imports(), disable_running=self.disable_running, ) alembic_postgresql_enum.set_configuration(old_config) diff --git a/tests/sync_enum_values/test_array_column.py b/tests/sync_enum_values/test_array_column.py index 6bf7542..fced88b 100644 --- a/tests/sync_enum_values/test_array_column.py +++ b/tests/sync_enum_values/test_array_column.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from alembic import autogenerate from alembic.autogenerate import api @@ -63,7 +63,8 @@ def get_expected_downgrade(self) -> str: # ### end Alembic commands ### """ - "from alembic_postgresql_enum import ColumnType" "\nfrom alembic_postgresql_enum import TableReference" + def get_expected_imports(self) -> Optional[str]: + return "from alembic_postgresql_enum import ColumnType" "\nfrom alembic_postgresql_enum import TableReference" def test_add_new_enum_value_diff_tuple_with_array(connection: "Connection"):