Skip to content

Commit

Permalink
Merge branch 'main' into fix-manual-maintenance-mode
Browse files Browse the repository at this point in the history
  • Loading branch information
madwort authored Nov 8, 2024
2 parents 1e12d2e + 3141f1e commit 34fb168
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 27 deletions.
16 changes: 8 additions & 8 deletions jobrunner/job_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,26 +298,26 @@ def delete_files(self, workspace: str, privacy: Privacy, paths: [str]) -> List[s
class NullExecutorAPI(ExecutorAPI):
    """Null implementation of ExecutorAPI.

    Every method is a stub that raises NotImplementedError; useful as a
    placeholder where no concrete executor is configured.  The
    `# pragma: nocover` markers exclude these unreachable-by-design bodies
    from coverage measurement.
    """

    def prepare(self, job_definition):  # pragma: nocover
        raise NotImplementedError

    def execute(self, job_definition):  # pragma: nocover
        raise NotImplementedError

    def finalize(self, job_definition):  # pragma: nocover
        raise NotImplementedError

    def terminate(self, job_definition):  # pragma: nocover
        raise NotImplementedError

    def get_status(self, job_definition):  # pragma: nocover
        raise NotImplementedError

    def get_results(self, job_definition):  # pragma: nocover
        raise NotImplementedError

    def cleanup(self, job_definition):  # pragma: nocover
        raise NotImplementedError

    def delete_files(self, workspace, privacy, paths):  # pragma: nocover
        raise NotImplementedError
43 changes: 27 additions & 16 deletions jobrunner/lib/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,31 @@ def migration(version, sql):
MIGRATIONS[version] = sql


def insert(item):
def generate_insert_sql(item):
    """Build a parameterised INSERT statement for a dataclass instance.

    Returns a ``(sql, fields)`` pair: the SQL contains one ``?`` placeholder
    per dataclass field, and ``fields`` lists those fields in placeholder
    order so callers can encode the matching parameter values.
    """
    table = item.__tablename__
    fields = dataclasses.fields(item)
    columns = ", ".join(escape(field.name) for field in fields)
    placeholders = ", ".join(["?"] * len(fields))
    sql = f"INSERT INTO {escape(table)} ({columns}) VALUES({placeholders})"
    return sql, fields


def insert(item):
    """Insert a dataclass instance as a new row in its table."""
    sql, fields = generate_insert_sql(item)
    # Field values are encoded (e.g. Enums/JSON) to SQLite-storable params.
    get_connection().execute(sql, encode_field_values(fields, item))


def upsert(item):
assert item.id
table = item.__tablename__
fields = dataclasses.fields(item)
columns = ", ".join(escape(field.name) for field in fields)
placeholders = ", ".join(["?"] * len(fields))
insert_sql, fields = generate_insert_sql(item)

updates = ", ".join(f"{escape(field.name)} = ?" for field in fields)
# Note: technically we update the id on conflict with this approach, which
# is unessecary, but it does not hurt and simplifies updates and params
# is unnecessary, but it does not hurt and simplifies updates and params
# parts of the query.
sql = f"""
INSERT INTO {escape(table)} ({columns}) VALUES({placeholders})
{insert_sql}
ON CONFLICT(id) DO UPDATE SET {updates}
"""
params = encode_field_values(fields, item)
Expand Down Expand Up @@ -102,7 +106,7 @@ def find_where(itemclass, **query_params):
return [itemclass(*decode_field_values(fields, row)) for row in cursor]


def find_all(itemclass):  # pragma: nocover
    """Return every stored row of *itemclass* (an unfiltered find_where)."""
    return find_where(itemclass)


Expand Down Expand Up @@ -155,15 +159,20 @@ def transaction():
return conn


def filename_or_get_default(filename=None):
    """Return *filename*, falling back to the configured database file.

    Centralises the ``None -> config.DATABASE_FILE`` default so callers
    (get_connection, ensure_valid_db, ensure_db) resolve it identically.
    """
    return config.DATABASE_FILE if filename is None else filename


def get_connection(filename=None):
"""Return the current configured connection."""
# The caching below means we get the same connection to the database every
# time which is done not so much for efficiency as so that we can easily
# implement transaction support without having to explicitly pass round a
# connection object. This is done on a per-thread basis to avoid potential
# threading issues.
if filename is None:
filename = config.DATABASE_FILE
filename = filename_or_get_default(filename)

# Looks icky but is documented `threading.local` usage
cache = CONNECTION_CACHE.__dict__
Expand Down Expand Up @@ -208,8 +217,7 @@ def ensure_valid_db(filename=None, migrations=MIGRATIONS):
# we store migrations in models, so make sure this has been imported to collect them
import jobrunner.models # noqa: F401

if filename is None:
filename = config.DATABASE_FILE
filename = filename_or_get_default(filename)

db_type, db_exists = db_status(filename)
if db_type == "file" and not db_exists:
Expand All @@ -234,8 +242,7 @@ def ensure_db(filename=None, migrations=MIGRATIONS, verbose=False):
# we store migrations in models, so make sure this has been imported to collect them
import jobrunner.models # noqa: F401

if filename is None:
filename = config.DATABASE_FILE
filename = filename_or_get_default(filename)

db_type, db_exists = db_status(filename)

Expand Down Expand Up @@ -293,8 +300,12 @@ def query_params_to_sql(params):
All parameters are implicitly ANDed together, and there's a bit of magic to
handle `field__in=list_of_values` queries, LIKE queries and Enum classes.
"""
if not params:
return "1 = 1", []

parts = []
values = []

for key, value in params.items():
if key.endswith("__in"):
field = key[:-4]
Expand All @@ -308,11 +319,11 @@ def query_params_to_sql(params):
else:
parts.append(f"{escape(key)} = ?")
values.append(value)

# Bit of a hack: convert any Enum instances to their values so we can use
# them in querying
values = [v.value if isinstance(v, Enum) else v for v in values]
if not parts:
parts = ["1 = 1"]

return " AND ".join(parts), values


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ omit = [
]

[tool.coverage.report]
fail_under = 82
fail_under = 84
show_missing = true
skip_covered = true

Expand Down
107 changes: 107 additions & 0 deletions tests/lib/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@
from jobrunner.lib.database import (
CONNECTION_CACHE,
MigrationNeeded,
count_where,
ensure_db,
ensure_valid_db,
exists_where,
find_one,
generate_insert_sql,
get_connection,
insert,
migrate_db,
query_params_to_sql,
select_values,
transaction,
update,
upsert,
)
from jobrunner.models import Job, State

Expand All @@ -37,6 +43,48 @@ def test_basic_roundtrip(tmp_work_dir):
assert job.output_spec == j.output_spec


def test_insert_in_transaction_success(tmp_work_dir):
    # A row inserted inside a committed transaction is visible afterwards.
    job = Job(
        id="foo123",
        job_request_id="bar123",
        state=State.RUNNING,
        output_spec={"hello": [1, 2, 3]},
    )

    with transaction():
        insert(job)

    j = find_one(Job, job_request_id__in=["bar123", "baz123"])
    assert job.id == j.id
    assert job.output_spec == j.output_spec


def test_insert_in_transaction_fail(tmp_work_dir):
    # An explicit ROLLBACK inside the transaction discards the insert, so
    # the subsequent lookup finds no matching row.
    job = Job(
        id="foo123",
        job_request_id="bar123",
        state=State.RUNNING,
        output_spec={"hello": [1, 2, 3]},
    )

    with transaction():
        insert(job)
        conn = get_connection()
        conn.execute("ROLLBACK")

    with pytest.raises(ValueError):
        find_one(Job, job_request_id__in=["bar123", "baz123"])


def test_generate_insert_sql(tmp_work_dir):
    # Pin the exact generated SQL: every Job field becomes a quoted column
    # with a matching "?" placeholder.
    job = Job(id="foo123", action="foo")
    sql, _ = generate_insert_sql(job)

    assert (
        sql
        == 'INSERT INTO "job" ("id", "job_request_id", "state", "repo_url", "commit", "workspace", "database_name", "action", "action_repo_url", "action_commit", "requires_outputs_from", "wait_for_job_ids", "run_command", "image_id", "output_spec", "outputs", "unmatched_outputs", "status_message", "status_code", "cancelled", "created_at", "updated_at", "started_at", "completed_at", "status_code_updated_at", "trace_context", "level4_excluded_files", "requires_db") VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)'
    )


def test_update(tmp_work_dir):
job = Job(id="foo123", action="foo")
insert(job)
Expand All @@ -45,6 +93,20 @@ def test_update(tmp_work_dir):
assert find_one(Job, id="foo123").action == "bar"


def test_upsert_insert(tmp_work_dir):
    # upsert() of an unseen id behaves like a plain insert.
    job = Job(id="foo123", action="bar")
    upsert(job)
    assert find_one(Job, id="foo123").action == "bar"


def test_upsert_update(tmp_work_dir):
    # upsert() of an existing id updates the stored row in place.
    job = Job(id="foo123", action="foo")
    insert(job)
    job.action = "bar"
    upsert(job)
    assert find_one(Job, id="foo123").action == "bar"


def test_update_excluding_a_field(tmp_work_dir):
job = Job(id="foo123", action="foo", commit="commit-of-glory")
insert(job)
Expand All @@ -56,6 +118,26 @@ def test_update_excluding_a_field(tmp_work_dir):
assert j.commit == "commit-of-glory"


def test_exists_where(tmp_work_dir):
    insert(Job(id="foo123", state=State.PENDING))
    insert(Job(id="foo124", state=State.RUNNING))
    insert(Job(id="foo125", state=State.FAILED))
    # Both "__in" list queries and plain equality queries should match.
    job_state_exists = exists_where(Job, state__in=[State.PENDING, State.FAILED])
    assert job_state_exists is True
    job_id_exists = exists_where(Job, id="foo124")
    assert job_id_exists is True


def test_count_where(tmp_work_dir):
    insert(Job(id="foo123", state=State.PENDING))
    insert(Job(id="foo124", state=State.RUNNING))
    insert(Job(id="foo125", state=State.FAILED))
    # "__in" query counts all matching states; equality counts a single id.
    jobs_in_states = count_where(Job, state__in=[State.PENDING, State.FAILED])
    assert jobs_in_states == 2
    jobs_with_id = count_where(Job, id="foo124")
    assert jobs_with_id == 1


def test_select_values(tmp_work_dir):
insert(Job(id="foo123", state=State.PENDING))
insert(Job(id="foo124", state=State.RUNNING))
Expand Down Expand Up @@ -204,3 +286,28 @@ def test_ensure_valid_db(tmp_path):
# does not raise when all is well
conn.execute("PRAGMA user_version=1")
ensure_valid_db(db, {1: "should not run"})


@pytest.mark.parametrize(
    "params,expected_sql_string,expected_sql_values",
    [
        # No filters collapses to an always-true clause.
        ({}, "1 = 1", []),
        ({"doubutsu": "neko"}, '"doubutsu" = ?', ["neko"]),
        ({"doubutsu__like": "ne%"}, '"doubutsu" LIKE ?', ["ne%"]),
        (
            {"doubutsu__in": ["neko", "kitsune", "nezumi"]},
            '"doubutsu" IN (?, ?, ?)',
            ["neko", "kitsune", "nezumi"],
        ),
        # Multiple parameters are implicitly ANDed together.
        (
            {"namae": "rosa", "doubutsu__in": ["neko"]},
            '"namae" = ? AND "doubutsu" IN (?)',
            ["rosa", "neko"],
        ),
        # Enum instances are unwrapped to their raw .value.
        ({"state": State.RUNNING}, '"state" = ?', ["running"]),
    ],
)
def test_query_params_to_sql(params, expected_sql_string, expected_sql_values):
    sql_string, sql_values = query_params_to_sql(params)
    assert sql_string == expected_sql_string
    assert sql_values == expected_sql_values
33 changes: 32 additions & 1 deletion tests/lib/test_string_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,39 @@
from jobrunner.lib.string_utils import project_name_from_url
import pytest

from jobrunner.lib.string_utils import project_name_from_url, slugify, tabulate


@pytest.mark.parametrize(
    "input_string,slug",
    [
        ("string", "string"),
        # non-ASCII characters are stripped
        ("neko猫", "neko"),
        # punctuation is stripped
        ("string!@#$%^&**()", "string"),
        # runs of underscores/hyphens/spaces collapse to a single hyphen
        ("string_______string-------string string", "string-string-string-string"),
        # leading/trailing separators are trimmed
        ("__string__", "string"),
    ],
)
def test_slugify(input_string, slug):
    assert slugify(input_string) == slug


def test_project_name_from_url():
    # The final path component is the project name, for git URLs (with or
    # without ".git"), POSIX paths and Windows paths alike.
    assert project_name_from_url("https://github.com/opensafely/test1.git") == "test1"
    assert project_name_from_url("https://github.com/opensafely/test2/") == "test2"
    assert project_name_from_url("/some/local/path/test3/") == "test3"
    assert project_name_from_url("C:\\some\\windows\\path\\test4\\") == "test4"


@pytest.mark.parametrize(
    "rows,formatted_output",
    [
        ([], ""),
        # NOTE(review): expected strings copied verbatim from the page; runs
        # of column-padding spaces may have been collapsed by the scrape —
        # confirm the exact padding against the repository.
        ([["one", "two"], ["three", "four"]], "one two \nthree four"),
        (
            [["verylongword", "b"], ["yeahyeahyeah", "猫猫猫"]],
            "verylongword b \nyeahyeahyeah 猫猫猫",
        ),
    ],
)
def test_tabulate(rows, formatted_output):
    assert tabulate(rows) == formatted_output
24 changes: 23 additions & 1 deletion tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from responses import matchers

from jobrunner import config, queries, sync
from jobrunner.models import JobRequest
from jobrunner.lib.database import find_where
from jobrunner.models import Job, JobRequest
from tests.factories import job_factory, metrics_factory


Expand Down Expand Up @@ -152,3 +153,24 @@ def test_session_request_flags(db, responses):

# if this works, our expected request was generated
sync.api_get("path", params={"backend": "test"})


def test_sync_empty_response(db, monkeypatch, requests_mock):
    # When job-server returns no job requests, sync() must not create jobs
    # and must not POST anything back.
    monkeypatch.setattr(
        "jobrunner.config.JOB_SERVER_ENDPOINT", "http://testserver/api/v2/"
    )
    requests_mock.get(
        "http://testserver/api/v2/job-requests/?backend=expectations",
        json={
            "results": [],
        },
    )
    sync.sync()

    # verify we did not post back to job-server
    assert requests_mock.last_request.text is None
    assert requests_mock.last_request.method == "GET"

    # also that we did not create any jobs
    jobs = find_where(Job)
    assert jobs == []

0 comments on commit 34fb168

Please sign in to comment.