Skip to content

Commit

Permalink
fix: Issue with runtime version_table_schema option.
Browse files Browse the repository at this point in the history
`version_table_schema` appears to be (one of?) the only options that could
meaningfully affect the runtime behavior as-was, because it changes
where an un-configured (through env.py) MigrationContext instance would
point at an incorrect bit of data.

Now we run the process through env.py to ensure that setting it picked
up.

In order to attempt to avoid excess calls through the env.py, the
defaults tests now internally avoid making calls to `self.current`
wherever possible.
  • Loading branch information
DanCardin committed Jun 27, 2023
1 parent d1cbea3 commit 8e417a3
Show file tree
Hide file tree
Showing 13 changed files with 184 additions and 27 deletions.
36 changes: 36 additions & 0 deletions examples/test_version_table_schema/alembic.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
[alembic]
script_location = migrations

[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
3 changes: 3 additions & 0 deletions examples/test_version_table_schema/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from pytest_mock_resources import create_postgres_fixture, Statements

alembic_engine = create_postgres_fixture(Statements("CREATE SCHEMA version_table_schema"))
29 changes: 29 additions & 0 deletions examples/test_version_table_schema/migrations/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from logging.config import fileConfig

from alembic import context
from sqlalchemy import engine_from_config, pool

from models import Base

fileConfig(context.config.config_file_name)
target_metadata = Base.metadata


connectable = context.config.attributes.get("connection", None)

if connectable is None:
connectable = engine_from_config(
context.config.get_section(context.config.config_ini_section),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)

with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
version_table_schema="version_table_schema",
)

with context.begin_transaction():
context.run_migrations()
24 changes: 24 additions & 0 deletions examples/test_version_table_schema/migrations/script.py.mako
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}


def upgrade():
${upgrades if upgrades else "pass"}


def downgrade():
${downgrades if downgrades else "pass"}
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import sqlalchemy as sa
from alembic import op

revision = "aaaaaaaaaaaa"
down_revision = None
branch_labels = None
depends_on = None


def upgrade():
op.create_table(
"foo",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("(CURRENT_TIMESTAMP)"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)


def downgrade():
op.drop_table("foo")
17 changes: 17 additions & 0 deletions examples/test_version_table_schema/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import sqlalchemy
from sqlalchemy import Column, types
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()


class CreatedAt(Base):
__tablename__ = "foo"

id = Column(types.Integer(), autoincrement=True, primary_key=True)

created_at = sqlalchemy.Column(
sqlalchemy.types.DateTime(timezone=True),
server_default=sqlalchemy.text("CURRENT_TIMESTAMP"),
nullable=False,
)
3 changes: 3 additions & 0 deletions examples/test_version_table_schema/setup.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[tool:pytest]
pytest_alembic_exclude = single_head_revision
pytest_alembic_include_experimental = downgrade_leaves_no_trace,all_models_register_on_metadata
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytest-alembic"
version = "0.10.5"
version = "0.10.6"
description = "A pytest plugin for verifying alembic migrations."
authors = [
"Dan Cardin <[email protected]>",
Expand Down
4 changes: 4 additions & 0 deletions src/pytest_alembic/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def configure(self, **kwargs):
for key, value in kwargs.items():
self.alembic_config.attributes[key] = value

def execute_fn(self, fn):
with EnvironmentContext(self.alembic_config, self.script, fn=fn):
self.script.run_env()

def run_command(self, command, *args, **kwargs):
self.stream_position = self.stdout.tell()

Expand Down
55 changes: 33 additions & 22 deletions src/pytest_alembic/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,17 @@ def heads(self) -> List[str]:
@property
def current(self) -> str:
"""Get the list of revision heads."""
current = "base"

def get_current(connection):
context = alembic.migration.MigrationContext.configure(connection)
heads = context.get_current_heads()
if heads:
return heads[0]
return None
def get_current(rev, _):
nonlocal current
if rev:
current = rev[0]

return []

self.command_executor.execute_fn(get_current)

current = self.connection_executor.run_task(get_current)
if current:
return current
return "base"
Expand Down Expand Up @@ -138,9 +140,11 @@ def raw_command(self, *args, **kwargs):
"""Execute a raw alembic command."""
return self.command_executor.run_command(*args, **kwargs)

def managed_upgrade(self, dest_revision):
def managed_upgrade(self, dest_revision, *, current=None, return_current=True):
"""Perform an upgrade one migration at a time, inserting static data at the given points."""
current = self.current
if current is None:
current = self.current

for current_revision, next_revision in self.history.revision_window(current, dest_revision):
before_upgrade_data = self.revision_data.get_before(next_revision)
self.insert_into(data=before_upgrade_data, revision=current_revision, table=None)
Expand All @@ -153,12 +157,16 @@ def managed_upgrade(self, dest_revision):
at_upgrade_data = self.revision_data.get_at(next_revision)
self.insert_into(data=at_upgrade_data, revision=next_revision, table=None)

current = self.current
return current
if return_current:
current = self.current
return current
return None

def managed_downgrade(self, dest_revision):
def managed_downgrade(self, dest_revision, *, current=None, return_current=True):
"""Perform an downgrade, one migration at a time."""
current = self.current
if current is None:
current = self.current

for next_revision, current_revision in reversed(
self.history.revision_window(dest_revision, current)
):
Expand All @@ -173,23 +181,25 @@ def managed_downgrade(self, dest_revision):
else:
raise

current = self.current
return current
if return_current:
current = self.current
return current
return None

def migrate_up_before(self, revision):
"""Migrate up to, but not including the given `revision`."""
preceeding_revision = self.history.previous_revision(revision)
return self.managed_upgrade(preceeding_revision)

def migrate_up_to(self, revision):
def migrate_up_to(self, revision, *, return_current: bool = True):
"""Migrate up to, and including the given `revision`."""
return self.managed_upgrade(revision)
return self.managed_upgrade(revision, return_current=return_current)

def migrate_up_one(self):
"""Migrate up by exactly one revision."""
current = self.current
next_revision = self.history.next_revision(current)
new_revision = self.managed_upgrade(next_revision)
new_revision = self.managed_upgrade(next_revision, current=current)
if current == new_revision:
return None
return new_revision
Expand All @@ -199,16 +209,17 @@ def migrate_down_before(self, revision):
next_revision = self.history.next_revision(revision)
return self.migrate_down_to(next_revision)

def migrate_down_to(self, revision):
def migrate_down_to(self, revision, *, return_current: bool = True):
"""Migrate down to, and including the given `revision`."""
self.history.validate_revision(revision)
self.managed_downgrade(revision)
self.managed_downgrade(revision, return_current=return_current)
return revision

def migrate_down_one(self):
"""Migrate down by exactly one revision."""
previous_revision = self.history.previous_revision(self.current)
self.managed_downgrade(previous_revision)
current = self.current
previous_revision = self.history.previous_revision(current)
self.managed_downgrade(previous_revision, current=current)
return previous_revision

def roundtrip_next_revision(self):
Expand Down
8 changes: 4 additions & 4 deletions src/pytest_alembic/tests/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_single_head_revision(alembic_runner):
def test_upgrade(alembic_runner):
"""Assert that the revision history can be run through from base to head."""
try:
alembic_runner.migrate_up_to("heads")
alembic_runner.migrate_up_to("heads", return_current=False)
except RuntimeError as e:
message = (
"Failed to upgrade to the head revision. This means the historical chain from an "
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_up_down_consistency(alembic_runner):
"""
for revision in alembic_runner.history.revisions:
try:
alembic_runner.migrate_up_to(revision)
alembic_runner.migrate_up_to(revision, return_current=False)
except Exception as e:
message = "Failed to upgrade through each revision individually."
raise AlembicTestFailure(
Expand All @@ -128,7 +128,7 @@ def test_up_down_consistency(alembic_runner):
break

try:
alembic_runner.migrate_down_to(revision)
alembic_runner.migrate_down_to(revision, return_current=False)
except NotImplementedError:
# In the event of a `NotImplementedError`, we should have the same semantics,
# as-if `minimum_downgrade_revision` was specified, but we'll emit a warning
Expand All @@ -148,7 +148,7 @@ def test_up_down_consistency(alembic_runner):

for revision in reversed(down_revisions):
try:
alembic_runner.migrate_up_to(revision)
alembic_runner.migrate_up_to(revision, return_current=False)
except Exception as e:
message = (
"Failed to upgrade through each revision individually after performing a "
Expand Down
5 changes: 5 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,8 @@ def test_skip_revision(pytester):
def test_pytest_alembic_tests_path(pytester):
"""Assert the pytest_alembic_tests_path can be overridden."""
run_pytest(pytester, passed=4, args=["-vv", "--test-alembic", "tests_"])


def test_version_table_schema(pytester):
"""Assert the setting the version_table_schema option functions correctly."""
run_pytest(pytester, passed=5)

0 comments on commit 8e417a3

Please sign in to comment.