Skip to content

Commit

Permalink
add data migration test
Browse files Browse the repository at this point in the history
  • Loading branch information
Allie Crevier committed Oct 21, 2020
1 parent 269845b commit e43fc7c
Show file tree
Hide file tree
Showing 5 changed files with 330 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def upgrade():
FROM replies, users
WHERE journalist_id=users.uuid;
""")
assert not replies_with_incorrect_associations


def downgrade():
Expand Down
Empty file added tests/migrations/__init__.py
Empty file.
141 changes: 141 additions & 0 deletions tests/migrations/test_a4bf1f58ce69.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# -*- coding: utf-8 -*-

import os
import random
import subprocess

from securedrop_client import db
from securedrop_client.db import Reply, User

from .utils import add_reply, add_source, add_user

# Seed the module-level RNG so the generated fixture data is reproducible across runs.
random.seed("=^..^=..^=..^=")


class UpgradeTester:
    """
    Verify that upgrading to the target migration results in the replacement of uuid with id of the
    user in the replies table's journalist_id column.
    """

    # Fixture sizes; after load_data() the sqlite row ids are 1..NUM_USERS and 1..NUM_SOURCES.
    NUM_USERS = 20
    NUM_SOURCES = 20
    NUM_REPLIES = 40

    def __init__(self, homedir):
        # Touch the database file first so make_session_maker() opens an existing (empty) db.
        subprocess.check_call(["sqlite3", os.path.join(homedir, "svs.sqlite"), ".databases"])
        self.session = db.make_session_maker(homedir)()

    def load_data(self):
        """
        Load data that has the bug where user.uuid is stored in replies.journalist_id.
        """
        for _ in range(self.NUM_SOURCES):
            add_source(self.session)

        self.session.commit()

        for i in range(self.NUM_USERS):
            if i == 0:
                # As of this migration, the server tells the client that the associated journalist
                # of a reply has been deleted by returning "deleted" as the uuid of the associated
                # journalist. This gets stored as the journalist_id in the replies table.
                #
                # Make sure to test this case as well.
                add_user(self.session, "deleted")
                source_id = random.randint(1, self.NUM_SOURCES - 1)
                add_reply(self.session, "deleted", source_id)
            else:
                add_user(self.session)

        self.session.commit()

        # Add replies from randomly-selected journalists to randomly-selected sources.
        # The range starts at 1 because one reply was already added in the "deleted" branch
        # above, so the total number of replies is NUM_REPLIES.
        for _ in range(1, self.NUM_REPLIES):
            journalist_id = random.randint(1, self.NUM_USERS - 1)
            journalist = self.session.query(User).filter_by(id=journalist_id).one()
            source_id = random.randint(1, self.NUM_SOURCES - 1)
            # Deliberately store the user's *uuid* (the pre-migration bug) instead of the id.
            add_reply(self.session, journalist.uuid, source_id)

        self.session.commit()
        self.session.close()

    def check_upgrade(self):
        """
        Make sure each reply in the replies table has the correct journalist_id stored for the
        associated journalist by making sure a User account exists with that journalist id.
        """
        replies = self.session.query(Reply).all()
        assert len(replies)

        for reply in replies:
            # Will fail if User does not exist
            self.session.query(User).filter_by(id=reply.journalist_id).one()

        self.session.close()


class DowngradeTester:
    """
    Verify that downgrading from the target migration keeps in place the updates from the migration
    since there is no need to add bad data back into the db (the migration is backwards compatible).
    """

    # Fixture sizes; after load_data() the sqlite row ids are 1..NUM_USERS and 1..NUM_SOURCES.
    NUM_USERS = 20
    NUM_SOURCES = 20
    NUM_REPLIES = 40

    def __init__(self, homedir):
        # Touch the database file first so make_session_maker() opens an existing (empty) db.
        subprocess.check_call(["sqlite3", os.path.join(homedir, "svs.sqlite"), ".databases"])
        self.session = db.make_session_maker(homedir)()

    def load_data(self):
        """
        Load post-migration data, where replies.journalist_id holds the integer users.id of the
        associated journalist (unlike UpgradeTester, which loads the buggy uuid-based data).
        """
        for _ in range(self.NUM_SOURCES):
            add_source(self.session)

        self.session.commit()

        for i in range(self.NUM_USERS):
            if i == 0:
                # As of this migration, the server tells the client that the associated journalist
                # of a reply has been deleted by returning "deleted" as the uuid of the associated
                # journalist. This gets stored as the journalist_id in the replies table.
                #
                # Make sure to test this case as well.
                add_user(self.session, "deleted")
                journalist = self.session.query(User).filter_by(uuid="deleted").one()
                source_id = random.randint(1, self.NUM_SOURCES - 1)
                add_reply(self.session, journalist.id, source_id)
            else:
                add_user(self.session)

        self.session.commit()

        # Add replies from randomly-selected journalists to randomly-selected sources.
        # The range starts at 1 because one reply was already added in the "deleted" branch
        # above, so the total number of replies is NUM_REPLIES.
        for _ in range(1, self.NUM_REPLIES):
            journalist_id = random.randint(1, self.NUM_USERS - 1)
            source_id = random.randint(1, self.NUM_SOURCES)
            add_reply(self.session, journalist_id, source_id)

        self.session.commit()
        self.session.close()

    def check_downgrade(self):
        """
        Make sure each reply in the replies table has the correct journalist_id stored for the
        associated journalist by making sure a User account exists with that journalist id.
        """
        replies = self.session.query(Reply).all()
        assert len(replies)

        for reply in replies:
            # Will fail if User does not exist. Every reply loaded above references a real
            # users.id, so this must always succeed (mirrors UpgradeTester.check_upgrade).
            self.session.query(User).filter_by(id=reply.journalist_id).one()

        self.session.close()
154 changes: 154 additions & 0 deletions tests/migrations/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# -*- coding: utf-8 -*-
import random
import string
from datetime import datetime
from typing import Optional, Union
from uuid import uuid4

from sqlalchemy import text
from sqlalchemy.orm.session import Session

from securedrop_client.db import DownloadError, Source

# Seed the module-level RNG so the generated fixture data is reproducible across runs.
random.seed("ᕕ( ᐛ )ᕗ")


def random_bool() -> bool:
    """Return True or False, chosen uniformly at random."""
    return random.getrandbits(1) == 1


def bool_or_none() -> Optional[bool]:
    """Return one of None, True, or False, chosen uniformly at random."""
    options = [None, True, False]
    return random.choice(options)


def random_name() -> str:
    """Return a random printable string of 1 to 100 characters to use as a name."""
    # Renamed from ``len`` to avoid shadowing the builtin.
    length = random.randint(1, 100)
    return random_chars(length)


def random_username() -> str:
    """Return a random printable string of 3 to 64 characters to use as a username."""
    # Renamed from ``len`` to avoid shadowing the builtin.
    length = random.randint(3, 64)
    return random_chars(length)


def random_chars(length: int, chars: str = string.printable) -> str:
    """Return a random string of ``length`` characters drawn from ``chars``.

    The first parameter was renamed from ``len`` to avoid shadowing the builtin;
    all callers in this module pass it positionally.
    """
    return "".join(random.choice(chars) for _ in range(length))


def random_ascii_chars(length: int, chars: str = string.ascii_lowercase) -> str:
    """Return a random string of ``length`` characters drawn from ``chars``
    (lowercase ASCII by default).

    The first parameter was renamed from ``len`` to avoid shadowing the builtin,
    and the previously-missing return annotation was added.
    """
    return "".join(random.choice(chars) for _ in range(length))


def random_datetime(nullable: bool = False):
    """Return a random datetime; when ``nullable`` is True, return None half the time.

    Days are capped at 28 so the value is valid in every month.
    """
    if nullable and random_bool():
        return None
    fields = dict(
        year=random.randint(1, 9999),
        month=random.randint(1, 12),
        day=random.randint(1, 28),
        hour=random.randint(0, 23),
        minute=random.randint(0, 59),
        second=random.randint(0, 59),
        microsecond=random.randint(0, 1000),
    )
    return datetime(**fields)


def add_source(session: Session) -> None:
    """Insert a sources row populated with randomized field values."""
    sql = """
    INSERT INTO sources (
        uuid,
        journalist_designation,
        last_updated,
        interaction_count
    )
    VALUES (
        :uuid,
        :journalist_designation,
        :last_updated,
        :interaction_count
    )
    """
    session.execute(
        text(sql),
        {
            "uuid": str(uuid4()),
            "journalist_designation": random_chars(50),
            "last_updated": random_datetime(nullable=True),
            "interaction_count": random.randint(0, 1000),
        },
    )


def add_user(session: Session, uuid: Optional[str] = None) -> None:
    """Insert a users row, using ``uuid`` when supplied and generating one otherwise."""
    journalist_uuid = uuid if uuid else str(uuid4())

    sql = """
    INSERT INTO users (uuid, username)
    VALUES (:uuid, :username)
    """
    session.execute(
        text(sql),
        {"uuid": journalist_uuid, "username": random_username()},
    )


def add_reply(session: Session, journalist_id: Union[str, int], source_id: int) -> None:
    """Insert a replies row populated with randomized field values.

    ``journalist_id`` may be either the integer users.id (post-migration data) or the
    user's uuid string, including the literal "deleted" (pre-migration data) — the
    migration testers exercise both, so the previous ``int`` annotation was wrong.
    """
    # is_decrypted only makes sense for downloaded replies.
    is_downloaded = random_bool() if random_bool() else None
    is_decrypted = random_bool() if is_downloaded else None

    # NOTE(review): assumes downloaderrors has been seeded by an earlier migration;
    # random.choice raises IndexError on an empty list — confirm against the schema
    # migrations that run before this helper is used.
    download_errors = session.query(DownloadError).all()
    download_error_ids = [error.id for error in download_errors]

    content = random_chars(1000) if is_downloaded else None

    source = session.query(Source).filter_by(id=source_id).one()

    # Keep file_counter consistent with the number of items already in the source's
    # collection, mirroring how the server numbers submissions/replies.
    file_counter = len(source.collection) + 1

    params = {
        "uuid": str(uuid4()),
        "journalist_id": journalist_id,
        "source_id": source_id,
        "filename": random_chars(50) + "-reply.gpg",
        "file_counter": file_counter,
        "size": random.randint(0, 1024 * 1024 * 500),
        "content": content,
        "is_downloaded": is_downloaded,
        "is_decrypted": is_decrypted,
        "download_error_id": random.choice(download_error_ids),
        "last_updated": random_datetime(),
    }
    sql = """
    INSERT INTO replies
    (
        uuid,
        journalist_id,
        source_id,
        filename,
        file_counter,
        size,
        content,
        is_downloaded,
        is_decrypted,
        download_error_id,
        last_updated
    )
    VALUES
    (
        :uuid,
        :journalist_id,
        :source_id,
        :filename,
        :file_counter,
        :size,
        :content,
        :is_downloaded,
        :is_decrypted,
        :download_error_id,
        :last_updated
    )
    """
    session.execute(text(sql), params)
35 changes: 35 additions & 0 deletions tests/test_alembic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
x.split(".")[0].split("_")[0] for x in os.listdir(MIGRATION_PATH) if x.endswith(".py")
]

# Migrations that have companion data-migration tests in tests/migrations/test_<revision>.py,
# each providing UpgradeTester and DowngradeTester classes.
DATA_MIGRATIONS = ["a4bf1f58ce69"]

WHITESPACE_REGEX = re.compile(r"\s+")


Expand Down Expand Up @@ -135,6 +137,24 @@ def test_alembic_migration_upgrade(alembic_config, config, migration):
upgrade(alembic_config, mig)


@pytest.mark.parametrize("migration", DATA_MIGRATIONS)
def test_alembic_migration_upgrade_with_data(alembic_config, config, migration, homedir):
    """
    Upgrade to one migration before the target migration, load data, then upgrade in order to test
    that the upgrade is successful when there is data.
    """
    # Local import: the module's top-of-file import block is outside this view.
    import importlib

    migrations = list_migrations(alembic_config, migration)
    if len(migrations) == 1:
        # The target is the very first migration: there is no earlier schema to load
        # data into, so there is nothing to test.
        return
    # Upgrade to the migration immediately preceding the target.
    upgrade(alembic_config, migrations[-2])
    # importlib.import_module is the documented replacement for __import__ + fromlist.
    mod = importlib.import_module("tests.migrations.test_{}".format(migration))
    upgrade_tester = mod.UpgradeTester(homedir)
    upgrade_tester.load_data()
    upgrade(alembic_config, migration)
    upgrade_tester.check_upgrade()


@pytest.mark.parametrize("migration", ALL_MIGRATIONS)
def test_alembic_migration_downgrade(alembic_config, config, migration):
# upgrade to the parameterized test case ("head")
Expand All @@ -148,6 +168,21 @@ def test_alembic_migration_downgrade(alembic_config, config, migration):
downgrade(alembic_config, mig)


@pytest.mark.parametrize("migration", DATA_MIGRATIONS)
def test_alembic_migration_downgrade_with_data(alembic_config, config, migration, homedir):
    """
    Upgrade to the target migration, load data, then downgrade in order to test that the downgrade
    is successful when there is data.
    """
    # Local import: the module's top-of-file import block is outside this view.
    import importlib

    upgrade(alembic_config, migration)
    # importlib.import_module is the documented replacement for __import__ + fromlist.
    mod = importlib.import_module("tests.migrations.test_{}".format(migration))
    downgrade_tester = mod.DowngradeTester(homedir)
    downgrade_tester.load_data()
    # Alembic's relative revision "-1" steps down one migration from the current head.
    downgrade(alembic_config, "-1")
    downgrade_tester.check_downgrade()


@pytest.mark.parametrize("migration", ALL_MIGRATIONS)
def test_schema_unchanged_after_up_then_downgrade(alembic_config, tmpdir, migration):
migrations = list_migrations(alembic_config, migration)
Expand Down

0 comments on commit e43fc7c

Please sign in to comment.