From 38b47afe0008154434d3f06d851557c30afe31f2 Mon Sep 17 00:00:00 2001 From: rusty Date: Sun, 17 Mar 2024 19:59:34 +0400 Subject: [PATCH 1/4] Fix all mypy errors --- alembic_postgresql_enum/add_create_type_false.py | 11 ++++------- .../add_postgres_using_to_text.py | 1 + alembic_postgresql_enum/compare_dispatch.py | 6 ++++++ alembic_postgresql_enum/connection.py | 3 ++- .../get_enum_data/declared_enums.py | 11 ++++++----- alembic_postgresql_enum/get_enum_data/types.py | 4 ++-- alembic_postgresql_enum/operations/create_enum.py | 1 + alembic_postgresql_enum/operations/drop_enum.py | 1 + tests/base/run_migration_test_abc.py | 12 ++++++++---- 9 files changed, 31 insertions(+), 19 deletions(-) diff --git a/alembic_postgresql_enum/add_create_type_false.py b/alembic_postgresql_enum/add_create_type_false.py index ea67246..9c90230 100644 --- a/alembic_postgresql_enum/add_create_type_false.py +++ b/alembic_postgresql_enum/add_create_type_false.py @@ -62,14 +62,10 @@ def add_create_type_false(upgrade_ops: UpgradeOps): if isinstance(operations_group, ModifyTableOps): for operation in operations_group.ops: if isinstance(operation, AddColumnOp): - column: Column = operation.column - - inject_repr_into_enums(column) - + inject_repr_into_enums(operation.column) elif isinstance(operation, DropColumnOp): - column: Column = operation._reverse.column - - inject_repr_into_enums(column) + assert operation._reverse is not None + inject_repr_into_enums(operation._reverse.column) elif isinstance(operations_group, CreateTableOp): for column in operations_group.columns: @@ -77,6 +73,7 @@ def add_create_type_false(upgrade_ops: UpgradeOps): inject_repr_into_enums(column) elif isinstance(operations_group, DropTableOp): + assert operations_group._reverse is not None for column in operations_group._reverse.columns: if isinstance(column, Column): inject_repr_into_enums(column) diff --git a/alembic_postgresql_enum/add_postgres_using_to_text.py b/alembic_postgresql_enum/add_postgres_using_to_text.py index 8e3116e..adf33ba 100644 --- a/alembic_postgresql_enum/add_postgres_using_to_text.py +++ b/alembic_postgresql_enum/add_postgres_using_to_text.py @@ -39,6 +39,7 @@ def _postgres_using_alter_column(autogen_context: AutogenContext, op: ops.AlterC def add_postgres_using_to_alter_operation(op: AlterColumnOp): + assert op.modify_type is not None op.kw["postgresql_using"] = f"{op.column_name}::{op.modify_type.name}" log.info("postgresql_using added to %r.%r alteration", op.table_name, op.column_name) op.__class__ = PostgresUsingAlterColumnOp diff --git a/alembic_postgresql_enum/compare_dispatch.py b/alembic_postgresql_enum/compare_dispatch.py index a61a90f..2efb9c6 100644 --- a/alembic_postgresql_enum/compare_dispatch.py +++ b/alembic_postgresql_enum/compare_dispatch.py @@ -39,6 +39,12 @@ def compare_enums( if isinstance(operations_group, CreateTableOp) and operations_group.schema not in schema_names: schema_names.append(operations_group.schema) + assert ( + autogen_context.dialect is not None + and autogen_context.dialect.default_schema_name is not None + and autogen_context.connection is not None + and autogen_context.metadata is not None + ) for schema in schema_names: default_schema = autogen_context.dialect.default_schema_name if schema is None: diff --git a/alembic_postgresql_enum/connection.py b/alembic_postgresql_enum/connection.py index e09b7ba..02d270d 100644 --- a/alembic_postgresql_enum/connection.py +++ b/alembic_postgresql_enum/connection.py @@ -1,10 +1,11 @@ from contextlib import contextmanager +from typing import Iterator import sqlalchemy @contextmanager -def get_connection(operations) -> sqlalchemy.engine.Connection: +def get_connection(operations) -> Iterator[sqlalchemy.engine.Connection]: """ SQLAlchemy 2.0 changes the operation binding location; bridge function to support both 1.x and 2.x. diff --git a/alembic_postgresql_enum/get_enum_data/declared_enums.py b/alembic_postgresql_enum/get_enum_data/declared_enums.py index 611da42..75706b8 100644 --- a/alembic_postgresql_enum/get_enum_data/declared_enums.py +++ b/alembic_postgresql_enum/get_enum_data/declared_enums.py @@ -1,5 +1,6 @@ from collections import defaultdict -from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING +from enum import Enum +from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast import sqlalchemy from sqlalchemy import MetaData @@ -93,16 +94,16 @@ def get_declared_enums( if not column_type_is_enum(column_type): continue - column_type_schema = column_type.schema or default_schema + column_type_schema = column_type.schema or default_schema # type: ignore[attr-defined] if column_type_schema != schema: continue - if column_type.name not in enum_name_to_values: - enum_name_to_values[column_type.name] = get_enum_values(column_type) + if column_type.name not in enum_name_to_values: # type: ignore[attr-defined] + enum_name_to_values[column_type.name] = get_enum_values(cast(sqlalchemy.Enum, column_type)) # type: ignore[attr-defined] table_schema = table.schema or default_schema column_default = get_column_default(connection, table_schema, table.name, column.name) - enum_name_to_table_references[column_type.name].add( + enum_name_to_table_references[column_type.name].add( # type: ignore[attr-defined] TableReference( table_schema=table_schema, table_name=table.name, diff --git a/alembic_postgresql_enum/get_enum_data/types.py b/alembic_postgresql_enum/get_enum_data/types.py index 42e299e..cb7cb11 100644 --- a/alembic_postgresql_enum/get_enum_data/types.py +++ b/alembic_postgresql_enum/get_enum_data/types.py @@ -20,9 +20,9 @@ def __repr__(self): class TableReference: table_name: str column_name: str - table_schema: Optional[str] = Unspecified # 'Unspecified' default is for migrations from older versions + table_schema: Optional[str] = Unspecified # type: ignore[assignment] # 'Unspecified' default is for migrations from older versions column_type: ColumnType = ColumnType.COMMON - existing_server_default: str = None + existing_server_default: str | None = None def __repr__(self): result_str = "TableReference(" diff --git a/alembic_postgresql_enum/operations/create_enum.py b/alembic_postgresql_enum/operations/create_enum.py index 9b4afed..6d0ca46 100644 --- a/alembic_postgresql_enum/operations/create_enum.py +++ b/alembic_postgresql_enum/operations/create_enum.py @@ -19,6 +19,7 @@ def reverse(self): @alembic.autogenerate.render.renderers.dispatch_for(CreateEnumOp) def render_create_enum_op(autogen_context: AutogenContext, op: CreateEnumOp): + assert autogen_context.dialect is not None if op.schema != autogen_context.dialect.default_schema_name: return f""" sa.Enum({', '.join(map(repr, op.enum_values))}, name='{op.name}', schema='{op.schema}').create(op.get_bind()) diff --git a/alembic_postgresql_enum/operations/drop_enum.py b/alembic_postgresql_enum/operations/drop_enum.py index 6872482..1f7f2ff 100644 --- a/alembic_postgresql_enum/operations/drop_enum.py +++ b/alembic_postgresql_enum/operations/drop_enum.py @@ -19,6 +19,7 @@ def reverse(self): @alembic.autogenerate.render.renderers.dispatch_for(DropEnumOp) def render_drop_enum_op(autogen_context: AutogenContext, op: DropEnumOp): + assert autogen_context.dialect is not None if op.schema != autogen_context.dialect.default_schema_name: return f""" sa.Enum({', '.join(map(repr, op.enum_values))}, name='{op.name}', schema='{op.schema}').drop(op.get_bind()) diff --git a/tests/base/run_migration_test_abc.py b/tests/base/run_migration_test_abc.py index fb2b7cb..b31a8b1 100644 --- a/tests/base/run_migration_test_abc.py +++ b/tests/base/run_migration_test_abc.py @@ -10,16 +10,20 @@ class CompareAndRunTestCase(ABC): @abstractmethod - def get_database_schema(self) -> MetaData: ... + def get_database_schema(self) -> MetaData: + ... @abstractmethod - def get_target_schema(self) -> MetaData: ... + def get_target_schema(self) -> MetaData: + ... @abstractmethod - def get_expected_upgrade(self) -> str: ... + def get_expected_upgrade(self) -> str: + ... @abstractmethod - def get_expected_downgrade(self) -> str: ... + def get_expected_downgrade(self) -> str: + ... def test_run(self, connection: "Connection"): database_schema = self.get_database_schema() From 832aa22c0f8a2b51dd1bd449f61568916994281d Mon Sep 17 00:00:00 2001 From: rusty Date: Sun, 17 Mar 2024 19:59:43 +0400 Subject: [PATCH 2/4] Add py.typed marker --- alembic_postgresql_enum/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 alembic_postgresql_enum/py.typed diff --git a/alembic_postgresql_enum/py.typed b/alembic_postgresql_enum/py.typed new file mode 100644 index 0000000..e69de29 From ebba945b4b123d894624085fd3b8e4ad767eb464 Mon Sep 17 00:00:00 2001 From: rusty Date: Sun, 17 Mar 2024 20:03:00 +0400 Subject: [PATCH 3/4] Fix black --- tests/base/run_migration_test_abc.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/base/run_migration_test_abc.py b/tests/base/run_migration_test_abc.py index b31a8b1..fb2b7cb 100644 --- a/tests/base/run_migration_test_abc.py +++ b/tests/base/run_migration_test_abc.py @@ -10,20 +10,16 @@ class CompareAndRunTestCase(ABC): @abstractmethod - def get_database_schema(self) -> MetaData: - ... + def get_database_schema(self) -> MetaData: ... @abstractmethod - def get_target_schema(self) -> MetaData: - ... + def get_target_schema(self) -> MetaData: ... @abstractmethod - def get_expected_upgrade(self) -> str: - ... + def get_expected_upgrade(self) -> str: ... @abstractmethod - def get_expected_downgrade(self) -> str: - ... + def get_expected_downgrade(self) -> str: ... def test_run(self, connection: "Connection"): database_schema = self.get_database_schema() From 074cfaf89e85bf98f9bb9ec63399ab9a9a6dd42f Mon Sep 17 00:00:00 2001 From: rusty Date: Sun, 17 Mar 2024 20:09:06 +0400 Subject: [PATCH 4/4] Fix type hint for older python version --- alembic_postgresql_enum/get_enum_data/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alembic_postgresql_enum/get_enum_data/types.py b/alembic_postgresql_enum/get_enum_data/types.py index cb7cb11..d71bd7f 100644 --- a/alembic_postgresql_enum/get_enum_data/types.py +++ b/alembic_postgresql_enum/get_enum_data/types.py @@ -22,7 +22,7 @@ class TableReference: column_name: str table_schema: Optional[str] = Unspecified # type: ignore[assignment] # 'Unspecified' default is for migrations from older versions column_type: ColumnType = ColumnType.COMMON - existing_server_default: str | None = None + existing_server_default: Optional[str] = None def __repr__(self): result_str = "TableReference("