diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py index 13d025b2..7282487b 100644 --- a/alembic/autogenerate/api.py +++ b/alembic/autogenerate/api.py @@ -2,7 +2,6 @@ import contextlib from typing import Any -from typing import Callable from typing import Dict from typing import Iterator from typing import List @@ -35,6 +34,7 @@ from ..operations.ops import UpgradeOps from ..runtime.environment import NameFilterParentNames from ..runtime.environment import NameFilterType + from ..runtime.environment import ProcessRevisionDirectiveFn from ..runtime.environment import RenderItemFn from ..runtime.migration import MigrationContext from ..script.base import Script @@ -510,13 +510,16 @@ class RevisionContext: file generation operation.""" generated_revisions: List[MigrationScript] + process_revision_directives: Optional[ProcessRevisionDirectiveFn] def __init__( self, config: Config, script_directory: ScriptDirectory, command_args: Dict[str, Any], - process_revision_directives: Optional[Callable] = None, + process_revision_directives: Optional[ + ProcessRevisionDirectiveFn + ] = None, ) -> None: self.config = config self.script_directory = script_directory diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index 1f4bcf89..9c84cd6c 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -245,6 +245,11 @@ def _add_table(autogen_context: AutogenContext, op: ops.CreateTableOp) -> str: comment = table.comment if comment: text += ",\ncomment=%r" % _ident(comment) + + info = table.info + if info: + text += f",\ninfo={info!r}" + for k in sorted(op.kw): text += ",\n%s=%r" % (k.replace(" ", "_"), op.kw[k]) diff --git a/alembic/autogenerate/rewriter.py b/alembic/autogenerate/rewriter.py index 4209c321..68a93dd0 100644 --- a/alembic/autogenerate/rewriter.py +++ b/alembic/autogenerate/rewriter.py @@ -9,19 +9,19 @@ from typing import TYPE_CHECKING from typing import Union -from alembic import util -from alembic.operations import ops +from .. import util +from ..operations import ops if TYPE_CHECKING: - from alembic.operations.ops import AddColumnOp - from alembic.operations.ops import AlterColumnOp - from alembic.operations.ops import CreateTableOp - from alembic.operations.ops import MigrateOperation - from alembic.operations.ops import MigrationScript - from alembic.operations.ops import ModifyTableOps - from alembic.operations.ops import OpContainer - from alembic.runtime.migration import MigrationContext - from alembic.script.revision import Revision + from ..operations.ops import AddColumnOp + from ..operations.ops import AlterColumnOp + from ..operations.ops import CreateTableOp + from ..operations.ops import MigrateOperation + from ..operations.ops import MigrationScript + from ..operations.ops import ModifyTableOps + from ..operations.ops import OpContainer + from ..runtime.environment import _GetRevArg + from ..runtime.migration import MigrationContext class Rewriter: @@ -119,7 +119,7 @@ def add_column_nullable(context, revision, op): def _rewrite( self, context: MigrationContext, - revision: Revision, + revision: _GetRevArg, directive: MigrateOperation, ) -> Iterator[MigrateOperation]: try: @@ -142,7 +142,7 @@ def _rewrite( def __call__( self, context: MigrationContext, - revision: Revision, + revision: _GetRevArg, directives: List[MigrationScript], ) -> None: self.process_revision_directives(context, revision, directives) @@ -153,7 +153,7 @@ def __call__( def _traverse_script( self, context: MigrationContext, - revision: Revision, + revision: _GetRevArg, directive: MigrationScript, ) -> None: upgrade_ops_list = [] @@ -180,7 +180,7 @@ def _traverse_script( def _traverse_op_container( self, context: MigrationContext, - revision: Revision, + revision: _GetRevArg, directive: OpContainer, ) -> None: self._traverse_list(context, revision, directive.ops) @@ -189,7 +189,7 @@ def _traverse_op_container( def _traverse_any_directive( self, context: MigrationContext, - revision: Revision, + revision: _GetRevArg, directive: MigrateOperation, ) -> None: pass @@ -197,7 +197,7 @@ def _traverse_any_directive( def _traverse_for( self, context: MigrationContext, - revision: Revision, + revision: _GetRevArg, directive: MigrateOperation, ) -> Any: directives = list(self._rewrite(context, revision, directive)) @@ -209,7 +209,7 @@ def _traverse_for( def _traverse_list( self, context: MigrationContext, - revision: Revision, + revision: _GetRevArg, directives: Any, ) -> None: dest = [] @@ -221,7 +221,7 @@ def _traverse_list( def process_revision_directives( self, context: MigrationContext, - revision: Revision, + revision: _GetRevArg, directives: List[MigrationScript], ) -> None: self._traverse_list(context, revision, directives) diff --git a/alembic/context.pyi b/alembic/context.pyi index 85e0cf75..f37f2461 100644 --- a/alembic/context.pyi +++ b/alembic/context.pyi @@ -7,6 +7,7 @@ from typing import Callable from typing import Collection from typing import ContextManager from typing import Dict +from typing import Iterable from typing import List from typing import Literal from typing import Mapping @@ -143,7 +144,12 @@ def configure( include_schemas: bool = False, process_revision_directives: Optional[ Callable[ - [MigrationContext, Tuple[str, str], List[MigrationScript]], None + [ + MigrationContext, + Union[str, Iterable[Optional[str]], Iterable[str]], + List[MigrationScript], + ], + None, ] ] = None, compare_type: Union[ diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py index a1c0e1b0..7640f563 100644 --- a/alembic/runtime/environment.py +++ b/alembic/runtime/environment.py @@ -23,6 +23,7 @@ from .migration import MigrationContext from .. import util from ..operations import Operations +from ..script.revision import _GetRevArg if TYPE_CHECKING: from sqlalchemy.engine import URL @@ -42,7 +43,7 @@ _RevNumber = Optional[Union[str, Tuple[str, ...]]] ProcessRevisionDirectiveFn = Callable[ - [MigrationContext, Tuple[str, str], List["MigrationScript"]], None + [MigrationContext, _GetRevArg, List["MigrationScript"]], None ] RenderItemFn = Callable[ diff --git a/alembic/script/revision.py b/alembic/script/revision.py index aa0e9040..03502644 100644 --- a/alembic/script/revision.py +++ b/alembic/script/revision.py @@ -32,14 +32,8 @@ _RevIdType = Union[str, List[str], Tuple[str, ...]] _GetRevArg = Union[ str, - List[Optional[str]], - Tuple[Optional[str], ...], - FrozenSet[Optional[str]], - Set[Optional[str]], - List[str], - Tuple[str, ...], - FrozenSet[str], - Set[str], + Iterable[Optional[str]], + Iterable[str], ] _RevisionIdentifierType = Union[str, Tuple[str, ...], None] _RevisionOrStr = Union["Revision", str] @@ -738,7 +732,7 @@ def _shares_lineage( ) def _resolve_revision_number( - self, id_: Optional[str] + self, id_: Optional[_GetRevArg] ) -> Tuple[Tuple[str, ...], Optional[str]]: branch_label: Optional[str] if isinstance(id_, str) and "@" in id_: diff --git a/docs/build/unreleased/1329.rst b/docs/build/unreleased/1329.rst new file mode 100644 index 00000000..b6065d95 --- /dev/null +++ b/docs/build/unreleased/1329.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, autogenerate, regression + :tickets: 1329 + + Fixed regression caused by :ticket:`879` released in 1.7.0 where the + ".info" dictionary of ``Table`` would not render in autogenerate create + table statements. This can be useful for custom create table DDL rendering + schemes so it is restored. diff --git a/tests/test_autogen_render.py b/tests/test_autogen_render.py index 5c200b1f..88aa978c 100644 --- a/tests/test_autogen_render.py +++ b/tests/test_autogen_render.py @@ -1983,6 +1983,27 @@ def test_render_table_with_comment(self): ")", ) + def test_render_table_with_info(self): + m = MetaData() + t = Table( + "test", + m, + Column("id", Integer, primary_key=True), + Column("q", Integer, ForeignKey("address.id")), + info={"oracle_partition": "PARTITION BY ..."}, + ) + op_obj = ops.CreateTableOp.from_table(t) + eq_ignore_whitespace( + autogenerate.render_op_text(self.autogen_context, op_obj), + "op.create_table('test'," + "sa.Column('id', sa.Integer(), nullable=False)," + "sa.Column('q', sa.Integer(), nullable=True)," + "sa.ForeignKeyConstraint(['q'], ['address.id'], )," + "sa.PrimaryKeyConstraint('id')," + "info={'oracle_partition': 'PARTITION BY ...'}" + ")", + ) + def test_render_add_column_with_comment(self): op_obj = ops.AddColumnOp( "foo", Column("x", Integer, comment="This is a Column")