Skip to content

Commit

Permalink
Merge pull request #71 from Pogchamp-company/feature/-/mypy
Browse files Browse the repository at this point in the history
Add py.typed flag and add more type hints
  • Loading branch information
RustyGuard authored Mar 31, 2024
2 parents 8b10e47 + 074cfaf commit 9d9d8a8
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 15 deletions.
11 changes: 4 additions & 7 deletions alembic_postgresql_enum/add_create_type_false.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,18 @@ def add_create_type_false(upgrade_ops: UpgradeOps):
if isinstance(operations_group, ModifyTableOps):
for operation in operations_group.ops:
if isinstance(operation, AddColumnOp):
column: Column = operation.column

inject_repr_into_enums(column)

inject_repr_into_enums(operation.column)
elif isinstance(operation, DropColumnOp):
column: Column = operation._reverse.column

inject_repr_into_enums(column)
assert operation._reverse is not None
inject_repr_into_enums(operation._reverse.column)

elif isinstance(operations_group, CreateTableOp):
for column in operations_group.columns:
if isinstance(column, Column):
inject_repr_into_enums(column)

elif isinstance(operations_group, DropTableOp):
assert operations_group._reverse is not None
for column in operations_group._reverse.columns:
if isinstance(column, Column):
inject_repr_into_enums(column)
1 change: 1 addition & 0 deletions alembic_postgresql_enum/add_postgres_using_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def _postgres_using_alter_column(autogen_context: AutogenContext, op: ops.AlterC


def add_postgres_using_to_alter_operation(op: AlterColumnOp):
assert op.modify_type is not None
op.kw["postgresql_using"] = f"{op.column_name}::{op.modify_type.name}"
log.info("postgresql_using added to %r.%r alteration", op.table_name, op.column_name)
op.__class__ = PostgresUsingAlterColumnOp
Expand Down
6 changes: 6 additions & 0 deletions alembic_postgresql_enum/compare_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ def compare_enums(
if isinstance(operations_group, CreateTableOp) and operations_group.schema not in schema_names:
schema_names.append(operations_group.schema)

assert (
autogen_context.dialect is not None
and autogen_context.dialect.default_schema_name is not None
and autogen_context.connection is not None
and autogen_context.metadata is not None
)
for schema in schema_names:
default_schema = autogen_context.dialect.default_schema_name
if schema is None:
Expand Down
3 changes: 2 additions & 1 deletion alembic_postgresql_enum/connection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from contextlib import contextmanager
from typing import Iterator

import sqlalchemy


@contextmanager
def get_connection(operations) -> sqlalchemy.engine.Connection:
def get_connection(operations) -> Iterator[sqlalchemy.engine.Connection]:
"""
SQLAlchemy 2.0 changes the operation binding location; bridge function to support
both 1.x and 2.x.
Expand Down
11 changes: 6 additions & 5 deletions alembic_postgresql_enum/get_enum_data/declared_enums.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import defaultdict
from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING
from enum import Enum
from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast

import sqlalchemy
from sqlalchemy import MetaData
Expand Down Expand Up @@ -93,16 +94,16 @@ def get_declared_enums(
if not column_type_is_enum(column_type):
continue

column_type_schema = column_type.schema or default_schema
column_type_schema = column_type.schema or default_schema # type: ignore[attr-defined]
if column_type_schema != schema:
continue

if column_type.name not in enum_name_to_values:
enum_name_to_values[column_type.name] = get_enum_values(column_type)
if column_type.name not in enum_name_to_values: # type: ignore[attr-defined]
enum_name_to_values[column_type.name] = get_enum_values(cast(sqlalchemy.Enum, column_type)) # type: ignore[attr-defined]

table_schema = table.schema or default_schema
column_default = get_column_default(connection, table_schema, table.name, column.name)
enum_name_to_table_references[column_type.name].add(
enum_name_to_table_references[column_type.name].add( # type: ignore[attr-defined]
TableReference(
table_schema=table_schema,
table_name=table.name,
Expand Down
4 changes: 2 additions & 2 deletions alembic_postgresql_enum/get_enum_data/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def __repr__(self):
class TableReference:
table_name: str
column_name: str
table_schema: Optional[str] = Unspecified # 'Unspecified' default is for migrations from older versions
table_schema: Optional[str] = Unspecified # type: ignore[assignment] # 'Unspecified' default is for migrations from older versions
column_type: ColumnType = ColumnType.COMMON
existing_server_default: str = None
existing_server_default: Optional[str] = None

def __repr__(self):
result_str = "TableReference("
Expand Down
1 change: 1 addition & 0 deletions alembic_postgresql_enum/operations/create_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def reverse(self):

@alembic.autogenerate.render.renderers.dispatch_for(CreateEnumOp)
def render_create_enum_op(autogen_context: AutogenContext, op: CreateEnumOp):
assert autogen_context.dialect is not None
if op.schema != autogen_context.dialect.default_schema_name:
return f"""
sa.Enum({', '.join(map(repr, op.enum_values))}, name='{op.name}', schema='{op.schema}').create(op.get_bind())
Expand Down
1 change: 1 addition & 0 deletions alembic_postgresql_enum/operations/drop_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def reverse(self):

@alembic.autogenerate.render.renderers.dispatch_for(DropEnumOp)
def render_drop_enum_op(autogen_context: AutogenContext, op: DropEnumOp):
assert autogen_context.dialect is not None
if op.schema != autogen_context.dialect.default_schema_name:
return f"""
sa.Enum({', '.join(map(repr, op.enum_values))}, name='{op.name}', schema='{op.schema}').drop(op.get_bind())
Expand Down
Empty file.

0 comments on commit 9d9d8a8

Please sign in to comment.