Skip to content

Commit

Permalink
fix: (Huge speed optimization) Avoid the use of the high-level alembi…
Browse files Browse the repository at this point in the history
…c command interface in most cases.

In particular, the `alembic.command.<x>` (where <x> is something like
upgrade, downgrade, or head(s)). These commands are (seemingly) intended
to be used as one-off commands because they internally create `ScriptDirectory`
objects.

Given that we execute numerous individual commands per test, that creation
step adds up leads to (with any appreciable history length) test
runtime. An `alembic upgrade head` which takes 20s can take up to 5m
before this change.

This PR switches to internally producing a `ScriptDirectory` once, and
reimplementing some of the internals of the `alembic.command` interface
becuase there isn't a public way of providing a script to those
functions.
  • Loading branch information
DanCardin committed Feb 8, 2022
1 parent 3404981 commit d616ffa
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 15 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## [Unreleased](https://github.com/schireson/pytest-alembic/compare/v0.7.0...HEAD) (2022-02-08)

### Fixes

* (Huge speed optimization) Avoid the use of the high-level alembic command interface in most cases. 1ae311f


## [v0.7.0](https://github.com/schireson/pytest-alembic/compare/v0.6.1...v0.7.0) (2021-12-21)

### ⚠ BREAKING CHANGE
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytest-alembic"
version = "0.7.0"
version = "0.8.0"
description = "A pytest plugin for verifying alembic migrations."
authors = [
"Dan Cardin <[email protected]>",
Expand Down
41 changes: 40 additions & 1 deletion src/pytest_alembic/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import alembic
import alembic.config
from alembic.runtime.environment import EnvironmentContext
from alembic.script.base import ScriptDirectory
from sqlalchemy import MetaData, Table
from sqlalchemy.engine import Connection

Expand All @@ -17,12 +19,18 @@ class CommandExecutor:
alembic_config: alembic.config.Config
stdout: StringIO
stream_position: int
script: ScriptDirectory

@classmethod
def from_config(cls, config: Config):
stdout = StringIO()
alembic_config = config.make_alembic_config(stdout)
return cls(alembic_config=alembic_config, stdout=stdout, stream_position=0)
return cls(
alembic_config=alembic_config,
stdout=stdout,
stream_position=0,
script=ScriptDirectory.from_config(alembic_config),
)

def configure(self, **kwargs):
for key, value in kwargs.items():
Expand All @@ -44,6 +52,37 @@ def run_command(self, command, *args, **kwargs):
self.stdout.seek(self.stream_position)
return self.stdout.readlines()

def heads(self):
return [rev.revision for rev in self.script.get_revisions("heads")]

def upgrade(self, revision):
"""Upgrade to the given `revision`."""

def upgrade(rev, _):
return self.script._upgrade_revs(revision, rev)

self._run_env(upgrade, revision)

def downgrade(self, revision):
"""Downgrade to the given `revision`."""

def downgrade(rev, _):
return self.script._downgrade_revs(revision, rev)

self._run_env(downgrade, revision)

def _run_env(self, fn, revision=None):
"""Execute the migrations' env.py, given some function to execute."""
dont_mutate = revision is None
with EnvironmentContext(
self.alembic_config,
self.script,
fn=fn,
destination_rev=revision,
dont_mutate=dont_mutate,
):
self.script.run_env()


@dataclass
class ConnectionExecutor:
Expand Down
42 changes: 29 additions & 13 deletions src/pytest_alembic/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import alembic.command
import alembic.migration
from sqlalchemy.engine import Engine

from pytest_alembic.config import Config
Expand Down Expand Up @@ -35,6 +37,7 @@ class MigrationContext:
revision_data: RevisionData
connection_executor: ConnectionExecutor
config: Config
history: AlembicHistory
connection: Engine = None

@classmethod
Expand All @@ -45,31 +48,44 @@ def from_config(
connection_executor: ConnectionExecutor,
connection: Engine,
):
# XXX: Perhaps we can avoid this parsing and `AlembicHistory` entirely
# and use whatever alembic uses internally. All `raw_command`
# invocations will be slow; although at least this specific one
# only happens once per test, so it's less important to optimize.
raw_history = command_executor.run_command("history")
history = AlembicHistory.parse(tuple(raw_history))

return cls(
command_executor=command_executor,
revision_data=RevisionData.from_config(config),
connection_executor=connection_executor,
config=config,
history=history,
connection=connection,
)

@property
def history(self) -> AlembicHistory:
"""Get the revision history."""
raw_history = self.command_executor.run_command("history")
return AlembicHistory.parse(tuple(raw_history))

@property
def heads(self) -> List[str]:
"""Get the list of revision heads."""
return self.command_executor.run_command("heads")
"""Get the list of revision heads.
Result is cached for the lifetime of the `MigrationContext`.
"""
return self.command_executor.heads()

@property
def current(self) -> str:
"""Get the list of revision heads."""
current = self.command_executor.run_command("current")

def get_current(conn):
context = alembic.migration.MigrationContext.configure(conn)
heads = context.get_current_heads()
if heads:
return heads[0]
return None

current = run_connection_task(self.connection, get_current)
if current:
return current[0].strip().split(" ")[0]
return current
return "base"

def generate_revision(self, process_revision_directives=None, **kwargs):
Expand Down Expand Up @@ -102,7 +118,7 @@ def managed_upgrade(self, dest_revision):
before_upgrade_data = list(self.revision_data.get_before(next_revision))
self.insert_into(data=before_upgrade_data, revision=current_revision, table=None)

self.raw_command("upgrade", next_revision)
self.command_executor.upgrade(next_revision)

at_upgrade_data = list(self.revision_data.get_at(next_revision))
self.insert_into(data=at_upgrade_data, revision=next_revision, table=None)
Expand Down Expand Up @@ -136,13 +152,13 @@ def migrate_down_before(self, revision):
def migrate_down_to(self, revision):
"""Migrate down to, and including the given `revision`."""
self.history.validate_revision(revision)
self.raw_command("downgrade", revision)
self.command_executor.downgrade(revision)
return revision

def migrate_down_one(self):
"""Migrate down by exactly one revision."""
previous_revision = self.history.previous_revision(self.current)
self.raw_command("downgrade", previous_revision)
self.command_executor.downgrade(previous_revision)
return previous_revision

def roundtrip_next_revision(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def _test_downgrade_leaves_no_trace(connection, alembic_runner: MigrationContext
# So we need to proceed by one.
alembic_runner.migrate_up_to(revision)

if hasattr(connection, "commit"):
connection.commit()


def check_revision_cycle(alembic_runner, connection, original_revision):
migration_context = alembic.migration.MigrationContext.configure(connection)
Expand Down

0 comments on commit d616ffa

Please sign in to comment.