Skip to content

Commit

Permalink
Improve Rewriter implementation
Browse files Browse the repository at this point in the history
Fixes sqlalchemy#1337 by:
* Fix the chaining of more than two rewriters
* Accept the chaining of a callable as well
  • Loading branch information
zrotceh committed Dec 11, 2023
1 parent 6827b4d commit 6ada3b1
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 8 deletions.
24 changes: 17 additions & 7 deletions alembic/autogenerate/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
Expand All @@ -23,6 +23,10 @@
from ..runtime.environment import _GetRevArg
from ..runtime.migration import MigrationContext

ProcessRevisionDirectiveFn = Callable[
[MigrationContext, _GetRevArg, List["MigrationScript"]], None
]


class Rewriter:
"""A helper object that allows easy 'rewriting' of ops streams.
Expand Down Expand Up @@ -52,15 +56,21 @@ class Rewriter:

_traverse = util.Dispatcher()

_chained: Optional[Rewriter] = None
_chained: Tuple[Union[ProcessRevisionDirectiveFn, Rewriter], ...] = ()

def __init__(self) -> None:
self.dispatch = util.Dispatcher()

def chain(self, other: Rewriter) -> Rewriter:
def chain(
self,
other: Union[
ProcessRevisionDirectiveFn,
Rewriter,
],
) -> Rewriter:
"""Produce a "chain" of this :class:`.Rewriter` to another.
This allows two rewriters to operate serially on a stream,
This allows two or more rewriters to operate serially on a stream,
e.g.::
writer1 = autogenerate.Rewriter()
Expand Down Expand Up @@ -89,7 +99,7 @@ def add_column_idx(context, revision, op):
"""
wr = self.__class__.__new__(self.__class__)
wr.__dict__.update(self.__dict__)
wr._chained = other
wr._chained += (other,)
return wr

def rewrites(
Expand Down Expand Up @@ -146,8 +156,8 @@ def __call__(
directives: List[MigrationScript],
) -> None:
self.process_revision_directives(context, revision, directives)
if self._chained:
self._chained(context, revision, directives)
for process_revision_directives in self._chained:
process_revision_directives(context, revision, directives)

@_traverse.dispatch_for(ops.MigrationScript)
def _traverse_script(
Expand Down
7 changes: 7 additions & 0 deletions docs/build/unreleased/1337.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
.. change::
:tags: bug, autogenerate
:tickets: 1337

Fixes `autogenerate.Rewriter` so that more than two instances could be
chained together correctly, and `process_revision_directives` callable
could also be chained.
15 changes: 14 additions & 1 deletion tests/test_script_production.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,11 @@ def add_column_idx(context, revision, op):
idx_op = ops.CreateIndexOp("ixt", op.table_name, [op.column.name])
return [op, idx_op]

def process_revision_directives(context, revision, generate_revisions):
generate_revisions[0].downgrade_ops = ops.DowngradeOps(
ops=[ops.DropColumnOp("t1", "x")]
)

directives = [
ops.MigrationScript(
util.rev_id(),
Expand All @@ -956,7 +961,8 @@ def add_column_idx(context, revision, op):
]

ctx, rev = mock.Mock(), mock.Mock()
writer1.chain(writer2)(ctx, rev, directives)
writer = writer1.chain(process_revision_directives).chain(writer2)
writer(ctx, rev, directives)

eq_(
autogenerate.render_python_code(directives[0].upgrade_ops),
Expand All @@ -970,6 +976,13 @@ def add_column_idx(context, revision, op):
" # ### end Alembic commands ###",
)

eq_(
autogenerate.render_python_code(directives[0].downgrade_ops),
"# ### commands auto generated by Alembic - please adjust! ###\n"
" op.drop_column('t1', 'x')\n"
" # ### end Alembic commands ###",
)

def test_no_needless_pass(self):
writer1 = autogenerate.Rewriter()

Expand Down

0 comments on commit 6ada3b1

Please sign in to comment.