Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Version 1.2.0 #75

Merged
merged 15 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions alembic_postgresql_enum/add_create_type_false.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,18 @@ def add_create_type_false(upgrade_ops: UpgradeOps):
if isinstance(operations_group, ModifyTableOps):
for operation in operations_group.ops:
if isinstance(operation, AddColumnOp):
column: Column = operation.column

inject_repr_into_enums(column)

inject_repr_into_enums(operation.column)
elif isinstance(operation, DropColumnOp):
column: Column = operation._reverse.column

inject_repr_into_enums(column)
assert operation._reverse is not None
inject_repr_into_enums(operation._reverse.column)

elif isinstance(operations_group, CreateTableOp):
for column in operations_group.columns:
if isinstance(column, Column):
inject_repr_into_enums(column)

elif isinstance(operations_group, DropTableOp):
assert operations_group._reverse is not None
for column in operations_group._reverse.columns:
if isinstance(column, Column):
inject_repr_into_enums(column)
1 change: 1 addition & 0 deletions alembic_postgresql_enum/add_postgres_using_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def _postgres_using_alter_column(autogen_context: AutogenContext, op: ops.AlterC


def add_postgres_using_to_alter_operation(op: AlterColumnOp):
assert op.modify_type is not None
op.kw["postgresql_using"] = f"{op.column_name}::{op.modify_type.name}"
log.info("postgresql_using added to %r.%r alteration", op.table_name, op.column_name)
op.__class__ = PostgresUsingAlterColumnOp
Expand Down
16 changes: 16 additions & 0 deletions alembic_postgresql_enum/compare_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import Iterable, Union

import alembic
Expand All @@ -16,6 +17,9 @@
from alembic_postgresql_enum.get_enum_data import get_defined_enums, get_declared_enums


log = logging.getLogger(f"alembic.{__name__}")


@alembic.autogenerate.comparators.dispatch_for("schema")
def compare_enums(
autogen_context: AutogenContext,
Expand All @@ -28,6 +32,12 @@ def compare_enums(
for each defined enum that has changed new entries when compared to its
declared version.
"""
if autogen_context.dialect.name != "postgresql":
log.warning(
f"This library only supports postgresql, but you are using {autogen_context.dialect.name}, skipping"
)
return

add_create_type_false(upgrade_ops)
add_postgres_using_to_text(upgrade_ops)

Expand All @@ -39,6 +49,12 @@ def compare_enums(
if isinstance(operations_group, CreateTableOp) and operations_group.schema not in schema_names:
schema_names.append(operations_group.schema)

assert (
autogen_context.dialect is not None
and autogen_context.dialect.default_schema_name is not None
and autogen_context.connection is not None
and autogen_context.metadata is not None
)
for schema in schema_names:
default_schema = autogen_context.dialect.default_schema_name
if schema is None:
Expand Down
3 changes: 2 additions & 1 deletion alembic_postgresql_enum/connection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from contextlib import contextmanager
from typing import Iterator

import sqlalchemy


@contextmanager
def get_connection(operations) -> sqlalchemy.engine.Connection:
def get_connection(operations) -> Iterator[sqlalchemy.engine.Connection]:
"""
SQLAlchemy 2.0 changes the operation binding location; bridge function to support
both 1.x and 2.x.
Expand Down
11 changes: 6 additions & 5 deletions alembic_postgresql_enum/get_enum_data/declared_enums.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import defaultdict
from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING
from enum import Enum
from typing import Tuple, Any, Set, Union, List, TYPE_CHECKING, cast

import sqlalchemy
from sqlalchemy import MetaData
Expand Down Expand Up @@ -93,16 +94,16 @@ def get_declared_enums(
if not column_type_is_enum(column_type):
continue

column_type_schema = column_type.schema or default_schema
column_type_schema = column_type.schema or default_schema # type: ignore[attr-defined]
if column_type_schema != schema:
continue

if column_type.name not in enum_name_to_values:
enum_name_to_values[column_type.name] = get_enum_values(column_type)
if column_type.name not in enum_name_to_values: # type: ignore[attr-defined]
enum_name_to_values[column_type.name] = get_enum_values(cast(sqlalchemy.Enum, column_type)) # type: ignore[attr-defined]

table_schema = table.schema or default_schema
column_default = get_column_default(connection, table_schema, table.name, column.name)
enum_name_to_table_references[column_type.name].add(
enum_name_to_table_references[column_type.name].add( # type: ignore[attr-defined]
TableReference(
table_schema=table_schema,
table_name=table.name,
Expand Down
6 changes: 3 additions & 3 deletions alembic_postgresql_enum/get_enum_data/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def __repr__(self):
class TableReference:
table_name: str
column_name: str
table_schema: Optional[str] = Unspecified # 'Unspecified' default is for migrations from older versions
table_schema: Optional[str] = Unspecified # type: ignore[assignment] # 'Unspecified' default is for migrations from older versions
column_type: ColumnType = ColumnType.COMMON
existing_server_default: str = None
existing_server_default: Optional[str] = None

def __repr__(self):
result_str = "TableReference("
Expand All @@ -48,7 +48,7 @@ def table_name_with_schema(self):
prefix = f"{self.table_schema}."
else:
prefix = ""
return f"{prefix}{self.table_name}"
return f'{prefix}"{self.table_name}"'


EnumNamesToValues = Dict[str, Tuple[str, ...]]
Expand Down
1 change: 1 addition & 0 deletions alembic_postgresql_enum/operations/create_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def reverse(self):

@alembic.autogenerate.render.renderers.dispatch_for(CreateEnumOp)
def render_create_enum_op(autogen_context: AutogenContext, op: CreateEnumOp):
assert autogen_context.dialect is not None
if op.schema != autogen_context.dialect.default_schema_name:
return f"""
sa.Enum({', '.join(map(repr, op.enum_values))}, name='{op.name}', schema='{op.schema}').create(op.get_bind())
Expand Down
1 change: 1 addition & 0 deletions alembic_postgresql_enum/operations/drop_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def reverse(self):

@alembic.autogenerate.render.renderers.dispatch_for(DropEnumOp)
def render_drop_enum_op(autogen_context: AutogenContext, op: DropEnumOp):
assert autogen_context.dialect is not None
if op.schema != autogen_context.dialect.default_schema_name:
return f"""
sa.Enum({', '.join(map(repr, op.enum_values))}, name='{op.name}', schema='{op.schema}').drop(op.get_bind())
Expand Down
10 changes: 10 additions & 0 deletions alembic_postgresql_enum/operations/sync_enum_values.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import List, Tuple, Any, Iterable, TYPE_CHECKING

import alembic.autogenerate
Expand Down Expand Up @@ -31,6 +32,9 @@
from alembic_postgresql_enum.get_enum_data import TableReference, ColumnType


log = logging.getLogger(f"alembic.{__name__}")


@alembic.operations.base.Operations.register_operation("sync_enum_values")
class SyncEnumValuesOp(alembic.operations.ops.MigrateOperation):
operation_name = "change_enum_variants"
Expand Down Expand Up @@ -138,6 +142,12 @@ def sync_enum_values(
]
If there was server default with old_name it will be renamed accordingly
"""
if operations.migration_context.dialect.name != "postgresql":
log.warning(
f"This library only supports postgresql, but you are using {operations.migration_context.dialect.name}, skipping"
)
return

enum_values_to_rename = list(enum_values_to_rename)

with get_connection(operations) as connection:
Expand Down
Empty file.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "alembic-postgresql-enum"
version = "1.1.2"
version = "1.2.0"
description = "Alembic autogenerate support for creation, alteration and deletion of enums"
authors = ["RustyGuard"]
license = "MIT"
Expand Down
4 changes: 4 additions & 0 deletions tests/base/render_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def compare_and_run(
*,
expected_upgrade: str,
expected_downgrade: str,
disable_running: bool = False,
):
"""Compares generated migration script is equal to expected_upgrade and expected_downgrade, then runs it"""
migration_context = create_migration_context(connection, target_schema)
Expand All @@ -39,6 +40,9 @@ def compare_and_run(
assert upgrade_code == expected_upgrade, f"Got:\n{upgrade_code!r}\nExpected:\n{expected_upgrade!r}"
assert downgrade_code == expected_downgrade, f"Got:\n{downgrade_code!r}\nExpected:\n{expected_downgrade!r}"

if disable_running:
return

exec(
upgrade_code,
{ # todo Use imports from template_args
Expand Down
11 changes: 11 additions & 0 deletions tests/base/run_migration_test_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,21 @@


class CompareAndRunTestCase(ABC):
"""
Base class for all tests that expect specific alembic generated code
"""

disable_running = False

@abstractmethod
def get_database_schema(self) -> MetaData: ...

@abstractmethod
def get_target_schema(self) -> MetaData: ...

def insert_migration_data(self, connection: "Connection", database_schema: MetaData) -> None:
pass

@abstractmethod
def get_expected_upgrade(self) -> str: ...

Expand All @@ -26,10 +35,12 @@ def test_run(self, connection: "Connection"):
target_schema = self.get_target_schema()

database_schema.create_all(connection)
self.insert_migration_data(connection, database_schema)

compare_and_run(
connection,
target_schema,
expected_upgrade=self.get_expected_upgrade(),
expected_downgrade=self.get_expected_downgrade(),
disable_running=self.disable_running,
)
89 changes: 47 additions & 42 deletions tests/test_alter_column/test_text_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,56 +4,61 @@
from sqlalchemy import MetaData, Table, Column, TEXT, insert
from sqlalchemy.dialects import postgresql

from tests.base.run_migration_test_abc import CompareAndRunTestCase

if TYPE_CHECKING:
from sqlalchemy import Connection

from tests.base.render_and_run import compare_and_run


class NewEnum(Enum):
A = "a"
B = "b"
C = "c"


def test_text_column(connection: "Connection"):
database_schema = MetaData()
a_table = Table("a", database_schema, Column("value", TEXT))
database_schema.create_all(connection)
connection.execute(
insert(a_table).values(
[
{"value": NewEnum.A.name},
{"value": NewEnum.B.name},
{"value": NewEnum.B.name},
{"value": NewEnum.C.name},
]
class TestTextColumn(CompareAndRunTestCase):
def get_database_schema(self) -> MetaData:
database_schema = MetaData()
Table("a", database_schema, Column("value", TEXT))
return database_schema

def get_target_schema(self) -> MetaData:
target_schema = MetaData()
Table("a", target_schema, Column("value", postgresql.ENUM(NewEnum)))
return target_schema

def insert_migration_data(self, connection: "Connection", database_schema: MetaData) -> None:
a_table = database_schema.tables["a"]
connection.execute(
insert(a_table).values(
[
{"value": NewEnum.A.name},
{"value": NewEnum.B.name},
{"value": NewEnum.B.name},
{"value": NewEnum.C.name},
]
)
)
)

target_schema = MetaData()
Table("a", target_schema, Column("value", postgresql.ENUM(NewEnum)))

compare_and_run(
connection,
target_schema,
expected_upgrade=f"""
# ### commands auto generated by Alembic - please adjust! ###
sa.Enum('A', 'B', 'C', name='newenum').create(op.get_bind())
op.alter_column('a', 'value',
existing_type=sa.TEXT(),
type_=postgresql.ENUM('A', 'B', 'C', name='newenum'),
existing_nullable=True,
postgresql_using='value::newenum')
# ### end Alembic commands ###
""",
expected_downgrade=f"""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('a', 'value',
existing_type=postgresql.ENUM('A', 'B', 'C', name='newenum'),
type_=sa.TEXT(),
existing_nullable=True)
sa.Enum('A', 'B', 'C', name='newenum').drop(op.get_bind())
# ### end Alembic commands ###
""",
)

def get_expected_upgrade(self) -> str:
return """
# ### commands auto generated by Alembic - please adjust! ###
sa.Enum('A', 'B', 'C', name='newenum').create(op.get_bind())
op.alter_column('a', 'value',
existing_type=sa.TEXT(),
type_=postgresql.ENUM('A', 'B', 'C', name='newenum'),
existing_nullable=True,
postgresql_using='value::newenum')
# ### end Alembic commands ###
"""

def get_expected_downgrade(self) -> str:
return """
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('a', 'value',
existing_type=postgresql.ENUM('A', 'B', 'C', name='newenum'),
type_=sa.TEXT(),
existing_nullable=True)
sa.Enum('A', 'B', 'C', name='newenum').drop(op.get_bind())
# ### end Alembic commands ###
"""
Loading
Loading