diff --git a/tests/lib/test_database.py b/tests/lib/test_database.py index e197969b..4479c1fe 100644 --- a/tests/lib/test_database.py +++ b/tests/lib/test_database.py @@ -17,6 +17,7 @@ migrate_db, query_params_to_sql, select_values, + transaction, update, upsert, ) @@ -42,6 +43,38 @@ def test_basic_roundtrip(tmp_work_dir): assert job.output_spec == j.output_spec +def test_insert_in_transaction_success(tmp_work_dir): + job = Job( + id="foo123", + job_request_id="bar123", + state=State.RUNNING, + output_spec={"hello": [1, 2, 3]}, + ) + + with transaction(): + insert(job) + j = find_one(Job, job_request_id__in=["bar123", "baz123"]) + assert job.id == j.id + assert job.output_spec == j.output_spec + + +def test_insert_in_transaction_fail(tmp_work_dir): + job = Job( + id="foo123", + job_request_id="bar123", + state=State.RUNNING, + output_spec={"hello": [1, 2, 3]}, + ) + + with transaction(): + insert(job) + conn = get_connection() + conn.execute("ROLLBACK") + + with pytest.raises(ValueError): + find_one(Job, job_request_id__in=["bar123", "baz123"]) + + def test_generate_insert_sql(tmp_work_dir): job = Job(id="foo123", action="foo") sql, _ = generate_insert_sql(job)