From 1344cfac43a1920c596b0e8718ca0567889e697b Mon Sep 17 00:00:00 2001 From: "Erlend E. Aasland" Date: Thu, 17 Aug 2023 08:45:48 +0200 Subject: [PATCH] gh-105539: Explict resource management for connection objects in sqlite3 tests (#108017) - Use memory_database() helper - Move test utility functions to util.py - Add convenience memory database mixin - Add check() helper for closed connection tests --- Lib/test/test_sqlite3/test_backup.py | 52 ++-- Lib/test/test_sqlite3/test_dbapi.py | 253 +++++++++----------- Lib/test/test_sqlite3/test_dump.py | 10 +- Lib/test/test_sqlite3/test_factory.py | 58 ++--- Lib/test/test_sqlite3/test_hooks.py | 79 +++--- Lib/test/test_sqlite3/test_regression.py | 123 +++++----- Lib/test/test_sqlite3/test_transactions.py | 33 +-- Lib/test/test_sqlite3/test_userfunctions.py | 74 ++---- Lib/test/test_sqlite3/util.py | 78 ++++++ 9 files changed, 373 insertions(+), 387 deletions(-) create mode 100644 Lib/test/test_sqlite3/util.py diff --git a/Lib/test/test_sqlite3/test_backup.py b/Lib/test/test_sqlite3/test_backup.py index 87ab29c54d65e2..4584d976bce0c6 100644 --- a/Lib/test/test_sqlite3/test_backup.py +++ b/Lib/test/test_sqlite3/test_backup.py @@ -1,6 +1,8 @@ import sqlite3 as sqlite import unittest +from .util import memory_database + class BackupTests(unittest.TestCase): def setUp(self): @@ -32,32 +34,32 @@ def test_bad_target_same_connection(self): self.cx.backup(self.cx) def test_bad_target_closed_connection(self): - bck = sqlite.connect(':memory:') - bck.close() - with self.assertRaises(sqlite.ProgrammingError): - self.cx.backup(bck) + with memory_database() as bck: + bck.close() + with self.assertRaises(sqlite.ProgrammingError): + self.cx.backup(bck) def test_bad_source_closed_connection(self): - bck = sqlite.connect(':memory:') - source = sqlite.connect(":memory:") - source.close() - with self.assertRaises(sqlite.ProgrammingError): - source.backup(bck) + with memory_database() as bck: + source = sqlite.connect(":memory:") + source.close() + with self.assertRaises(sqlite.ProgrammingError): + source.backup(bck) def test_bad_target_in_transaction(self): - bck = sqlite.connect(':memory:') - bck.execute('CREATE TABLE bar (key INTEGER)') - bck.executemany('INSERT INTO bar (key) VALUES (?)', [(3,), (4,)]) - with self.assertRaises(sqlite.OperationalError) as cm: - self.cx.backup(bck) + with memory_database() as bck: + bck.execute('CREATE TABLE bar (key INTEGER)') + bck.executemany('INSERT INTO bar (key) VALUES (?)', [(3,), (4,)]) + with self.assertRaises(sqlite.OperationalError) as cm: + self.cx.backup(bck) def test_keyword_only_args(self): with self.assertRaises(TypeError): - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, 1) def test_simple(self): - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck) self.verify_backup(bck) @@ -67,7 +69,7 @@ def test_progress(self): def progress(status, remaining, total): journal.append(status) - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, pages=1, progress=progress) self.verify_backup(bck) @@ -81,7 +83,7 @@ def test_progress_all_pages_at_once_1(self): def progress(status, remaining, total): journal.append(remaining) - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, progress=progress) self.verify_backup(bck) @@ -94,7 +96,7 @@ def test_progress_all_pages_at_once_2(self): def progress(status, remaining, total): journal.append(remaining) - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, pages=-1, progress=progress) self.verify_backup(bck) @@ -103,7 +105,7 @@ def progress(status, remaining, total): def test_non_callable_progress(self): with self.assertRaises(TypeError) as cm: - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, pages=1, progress='bar') self.assertEqual(str(cm.exception), 'progress argument must be a callable') @@ -116,7 +118,7 @@ def progress(status, remaining, total): self.cx.commit() journal.append(remaining) - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, pages=1, progress=progress) self.verify_backup(bck) @@ -140,12 +142,12 @@ def progress(status, remaining, total): self.assertEqual(str(err.exception), 'nearly out of space') def test_database_source_name(self): - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, name='main') - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, name='temp') with self.assertRaises(sqlite.OperationalError) as cm: - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, name='non-existing') self.assertIn("unknown database", str(cm.exception)) @@ -153,7 +155,7 @@ def test_database_source_name(self): self.cx.execute('CREATE TABLE attached_db.foo (key INTEGER)') self.cx.executemany('INSERT INTO attached_db.foo (key) VALUES (?)', [(3,), (4,)]) self.cx.commit() - with sqlite.connect(':memory:') as bck: + with memory_database() as bck: self.cx.backup(bck, name='attached_db') self.verify_backup(bck) diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py index c9a9e1353938c6..df3c2ea8d1dbda 100644 --- a/Lib/test/test_sqlite3/test_dbapi.py +++ b/Lib/test/test_sqlite3/test_dbapi.py @@ -33,26 +33,13 @@ SHORT_TIMEOUT, check_disallow_instantiation, requires_subprocess, is_emscripten, is_wasi ) +from test.support import gc_collect from test.support import threading_helper from _testcapi import INT_MAX, ULLONG_MAX from os import SEEK_SET, SEEK_CUR, SEEK_END from test.support.os_helper import TESTFN, TESTFN_UNDECODABLE, unlink, temp_dir, FakePath - -# Helper for temporary memory databases -def memory_database(*args, **kwargs): - cx = sqlite.connect(":memory:", *args, **kwargs) - return contextlib.closing(cx) - - -# Temporarily limit a database connection parameter -@contextlib.contextmanager -def cx_limit(cx, category=sqlite.SQLITE_LIMIT_SQL_LENGTH, limit=128): - try: - _prev = cx.setlimit(category, limit) - yield limit - finally: - cx.setlimit(category, _prev) +from .util import memory_database, cx_limit class ModuleTests(unittest.TestCase): @@ -326,9 +313,9 @@ def test_extended_error_code_on_exception(self): self.assertEqual(exc.sqlite_errorname, "SQLITE_CONSTRAINT_CHECK") def test_disallow_instantiation(self): - cx = sqlite.connect(":memory:") - check_disallow_instantiation(self, type(cx("select 1"))) - check_disallow_instantiation(self, sqlite.Blob) + with memory_database() as cx: + check_disallow_instantiation(self, type(cx("select 1"))) + check_disallow_instantiation(self, sqlite.Blob) def test_complete_statement(self): self.assertFalse(sqlite.complete_statement("select t")) @@ -342,6 +329,7 @@ def setUp(self): cu = self.cx.cursor() cu.execute("create table test(id integer primary key, name text)") cu.execute("insert into test(name) values (?)", ("foo",)) + cu.close() def tearDown(self): self.cx.close() @@ -412,21 +400,22 @@ def test_exceptions(self): def test_in_transaction(self): # Can't use db from setUp because we want to test initial state. - cx = sqlite.connect(":memory:") - cu = cx.cursor() - self.assertEqual(cx.in_transaction, False) - cu.execute("create table transactiontest(id integer primary key, name text)") - self.assertEqual(cx.in_transaction, False) - cu.execute("insert into transactiontest(name) values (?)", ("foo",)) - self.assertEqual(cx.in_transaction, True) - cu.execute("select name from transactiontest where name=?", ["foo"]) - row = cu.fetchone() - self.assertEqual(cx.in_transaction, True) - cx.commit() - self.assertEqual(cx.in_transaction, False) - cu.execute("select name from transactiontest where name=?", ["foo"]) - row = cu.fetchone() - self.assertEqual(cx.in_transaction, False) + with memory_database() as cx: + cu = cx.cursor() + self.assertEqual(cx.in_transaction, False) + cu.execute("create table transactiontest(id integer primary key, name text)") + self.assertEqual(cx.in_transaction, False) + cu.execute("insert into transactiontest(name) values (?)", ("foo",)) + self.assertEqual(cx.in_transaction, True) + cu.execute("select name from transactiontest where name=?", ["foo"]) + row = cu.fetchone() + self.assertEqual(cx.in_transaction, True) + cx.commit() + self.assertEqual(cx.in_transaction, False) + cu.execute("select name from transactiontest where name=?", ["foo"]) + row = cu.fetchone() + self.assertEqual(cx.in_transaction, False) + cu.close() def test_in_transaction_ro(self): with self.assertRaises(AttributeError): @@ -450,10 +439,9 @@ def test_connection_exceptions(self): self.assertIs(getattr(sqlite, exc), getattr(self.cx, exc)) def test_interrupt_on_closed_db(self): - cx = sqlite.connect(":memory:") - cx.close() + self.cx.close() with self.assertRaises(sqlite.ProgrammingError): - cx.interrupt() + self.cx.interrupt() def test_interrupt(self): self.assertIsNone(self.cx.interrupt()) @@ -521,29 +509,29 @@ def test_connection_init_good_isolation_levels(self): self.assertEqual(cx.isolation_level, level) def test_connection_reinit(self): - db = ":memory:" - cx = sqlite.connect(db) - cx.text_factory = bytes - cx.row_factory = sqlite.Row - cu = cx.cursor() - cu.execute("create table foo (bar)") - cu.executemany("insert into foo (bar) values (?)", - ((str(v),) for v in range(4))) - cu.execute("select bar from foo") - - rows = [r for r in cu.fetchmany(2)] - self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows)) - self.assertEqual([r[0] for r in rows], [b"0", b"1"]) - - cx.__init__(db) - cx.execute("create table foo (bar)") - cx.executemany("insert into foo (bar) values (?)", - ((v,) for v in ("a", "b", "c", "d"))) - - # This uses the old database, old row factory, but new text factory - rows = [r for r in cu.fetchall()] - self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows)) - self.assertEqual([r[0] for r in rows], ["2", "3"]) + with memory_database() as cx: + cx.text_factory = bytes + cx.row_factory = sqlite.Row + cu = cx.cursor() + cu.execute("create table foo (bar)") + cu.executemany("insert into foo (bar) values (?)", + ((str(v),) for v in range(4))) + cu.execute("select bar from foo") + + rows = [r for r in cu.fetchmany(2)] + self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows)) + self.assertEqual([r[0] for r in rows], [b"0", b"1"]) + + cx.__init__(":memory:") + cx.execute("create table foo (bar)") + cx.executemany("insert into foo (bar) values (?)", + ((v,) for v in ("a", "b", "c", "d"))) + + # This uses the old database, old row factory, but new text factory + rows = [r for r in cu.fetchall()] + self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows)) + self.assertEqual([r[0] for r in rows], ["2", "3"]) + cu.close() def test_connection_bad_reinit(self): cx = sqlite.connect(":memory:") @@ -591,11 +579,11 @@ def test_connect_positional_arguments(self): "parameters in Python 3.15." ) with self.assertWarnsRegex(DeprecationWarning, regex) as cm: - sqlite.connect(":memory:", 1.0) + cx = sqlite.connect(":memory:", 1.0) + cx.close() self.assertEqual(cm.filename, __file__) - class UninitialisedConnectionTests(unittest.TestCase): def setUp(self): self.cx = sqlite.Connection.__new__(sqlite.Connection) @@ -1571,12 +1559,12 @@ def run(con, err): except sqlite.Error: err.append("multi-threading not allowed") - con = sqlite.connect(":memory:", check_same_thread=False) - err = [] - t = threading.Thread(target=run, kwargs={"con": con, "err": err}) - t.start() - t.join() - self.assertEqual(len(err), 0, "\n".join(err)) + with memory_database(check_same_thread=False) as con: + err = [] + t = threading.Thread(target=run, kwargs={"con": con, "err": err}) + t.start() + t.join() + self.assertEqual(len(err), 0, "\n".join(err)) class ConstructorTests(unittest.TestCase): @@ -1602,9 +1590,16 @@ def test_binary(self): b = sqlite.Binary(b"\0'") class ExtensionTests(unittest.TestCase): + def setUp(self): + self.con = sqlite.connect(":memory:") + self.cur = self.con.cursor() + + def tearDown(self): + self.cur.close() + self.con.close() + def test_script_string_sql(self): - con = sqlite.connect(":memory:") - cur = con.cursor() + cur = self.cur cur.executescript(""" -- bla bla /* a stupid comment */ @@ -1616,40 +1611,40 @@ def test_script_string_sql(self): self.assertEqual(res, 5) def test_script_syntax_error(self): - con = sqlite.connect(":memory:") - cur = con.cursor() with self.assertRaises(sqlite.OperationalError): - cur.executescript("create table test(x); asdf; create table test2(x)") + self.cur.executescript(""" + CREATE TABLE test(x); + asdf; + CREATE TABLE test2(x) + """) def test_script_error_normal(self): - con = sqlite.connect(":memory:") - cur = con.cursor() with self.assertRaises(sqlite.OperationalError): - cur.executescript("create table test(sadfsadfdsa); select foo from hurz;") + self.cur.executescript(""" + CREATE TABLE test(sadfsadfdsa); + SELECT foo FROM hurz; + """) def test_cursor_executescript_as_bytes(self): - con = sqlite.connect(":memory:") - cur = con.cursor() with self.assertRaises(TypeError): - cur.executescript(b"create table test(foo); insert into test(foo) values (5);") + self.cur.executescript(b""" + CREATE TABLE test(foo); + INSERT INTO test(foo) VALUES (5); + """) def test_cursor_executescript_with_null_characters(self): - con = sqlite.connect(":memory:") - cur = con.cursor() with self.assertRaises(ValueError): - cur.executescript(""" - create table a(i);\0 - insert into a(i) values (5); - """) + self.cur.executescript(""" + CREATE TABLE a(i);\0 + INSERT INTO a(i) VALUES (5); + """) def test_cursor_executescript_with_surrogates(self): - con = sqlite.connect(":memory:") - cur = con.cursor() with self.assertRaises(UnicodeEncodeError): - cur.executescript(""" - create table a(s); - insert into a(s) values ('\ud8ff'); - """) + self.cur.executescript(""" + CREATE TABLE a(s); + INSERT INTO a(s) VALUES ('\ud8ff'); + """) def test_cursor_executescript_too_large_script(self): msg = "query string is too large" @@ -1659,19 +1654,18 @@ def test_cursor_executescript_too_large_script(self): cx.executescript("select 'too large'".ljust(lim+1)) def test_cursor_executescript_tx_control(self): - con = sqlite.connect(":memory:") + con = self.con con.execute("begin") self.assertTrue(con.in_transaction) con.executescript("select 1") self.assertFalse(con.in_transaction) def test_connection_execute(self): - con = sqlite.connect(":memory:") - result = con.execute("select 5").fetchone()[0] + result = self.con.execute("select 5").fetchone()[0] self.assertEqual(result, 5, "Basic test of Connection.execute") def test_connection_executemany(self): - con = sqlite.connect(":memory:") + con = self.con con.execute("create table test(foo)") con.executemany("insert into test(foo) values (?)", [(3,), (4,)]) result = con.execute("select foo from test order by foo").fetchall() @@ -1679,47 +1673,44 @@ def test_connection_executemany(self): self.assertEqual(result[1][0], 4, "Basic test of Connection.executemany") def test_connection_executescript(self): - con = sqlite.connect(":memory:") - con.executescript("create table test(foo); insert into test(foo) values (5);") + con = self.con + con.executescript(""" + CREATE TABLE test(foo); + INSERT INTO test(foo) VALUES (5); + """) result = con.execute("select foo from test").fetchone()[0] self.assertEqual(result, 5, "Basic test of Connection.executescript") + class ClosedConTests(unittest.TestCase): + def check(self, fn, *args, **kwds): + regex = "Cannot operate on a closed database." + with self.assertRaisesRegex(sqlite.ProgrammingError, regex): + fn(*args, **kwds) + + def setUp(self): + self.con = sqlite.connect(":memory:") + self.cur = self.con.cursor() + self.con.close() + def test_closed_con_cursor(self): - con = sqlite.connect(":memory:") - con.close() - with self.assertRaises(sqlite.ProgrammingError): - cur = con.cursor() + self.check(self.con.cursor) def test_closed_con_commit(self): - con = sqlite.connect(":memory:") - con.close() - with self.assertRaises(sqlite.ProgrammingError): - con.commit() + self.check(self.con.commit) def test_closed_con_rollback(self): - con = sqlite.connect(":memory:") - con.close() - with self.assertRaises(sqlite.ProgrammingError): - con.rollback() + self.check(self.con.rollback) def test_closed_cur_execute(self): - con = sqlite.connect(":memory:") - cur = con.cursor() - con.close() - with self.assertRaises(sqlite.ProgrammingError): - cur.execute("select 4") + self.check(self.cur.execute, "select 4") def test_closed_create_function(self): - con = sqlite.connect(":memory:") - con.close() - def f(x): return 17 - with self.assertRaises(sqlite.ProgrammingError): - con.create_function("foo", 1, f) + def f(x): + return 17 + self.check(self.con.create_function, "foo", 1, f) def test_closed_create_aggregate(self): - con = sqlite.connect(":memory:") - con.close() class Agg: def __init__(self): pass @@ -1727,29 +1718,21 @@ def step(self, x): pass def finalize(self): return 17 - with self.assertRaises(sqlite.ProgrammingError): - con.create_aggregate("foo", 1, Agg) + self.check(self.con.create_aggregate, "foo", 1, Agg) def test_closed_set_authorizer(self): - con = sqlite.connect(":memory:") - con.close() def authorizer(*args): return sqlite.DENY - with self.assertRaises(sqlite.ProgrammingError): - con.set_authorizer(authorizer) + self.check(self.con.set_authorizer, authorizer) def test_closed_set_progress_callback(self): - con = sqlite.connect(":memory:") - con.close() - def progress(): pass - with self.assertRaises(sqlite.ProgrammingError): - con.set_progress_handler(progress, 100) + def progress(): + pass + self.check(self.con.set_progress_handler, progress, 100) def test_closed_call(self): - con = sqlite.connect(":memory:") - con.close() - with self.assertRaises(sqlite.ProgrammingError): - con() + self.check(self.con) + class ClosedCurTests(unittest.TestCase): def test_closed(self): diff --git a/Lib/test/test_sqlite3/test_dump.py b/Lib/test/test_sqlite3/test_dump.py index d0c24b9c60e613..5f6811fb5cc0a5 100644 --- a/Lib/test/test_sqlite3/test_dump.py +++ b/Lib/test/test_sqlite3/test_dump.py @@ -2,16 +2,12 @@ import unittest import sqlite3 as sqlite -from .test_dbapi import memory_database +from .util import memory_database +from .util import MemoryDatabaseMixin -class DumpTests(unittest.TestCase): - def setUp(self): - self.cx = sqlite.connect(":memory:") - self.cu = self.cx.cursor() - def tearDown(self): - self.cx.close() +class DumpTests(MemoryDatabaseMixin, unittest.TestCase): def test_table_dump(self): expected_sqls = [ diff --git a/Lib/test/test_sqlite3/test_factory.py b/Lib/test/test_sqlite3/test_factory.py index d63589483e1042..a7c4417862aff7 100644 --- a/Lib/test/test_sqlite3/test_factory.py +++ b/Lib/test/test_sqlite3/test_factory.py @@ -24,6 +24,9 @@ import sqlite3 as sqlite from collections.abc import Sequence +from .util import memory_database +from .util import MemoryDatabaseMixin + def dict_factory(cursor, row): d = {} @@ -45,10 +48,12 @@ class OkFactory(sqlite.Connection): def __init__(self, *args, **kwargs): sqlite.Connection.__init__(self, *args, **kwargs) - for factory in DefectFactory, OkFactory: - with self.subTest(factory=factory): - con = sqlite.connect(":memory:", factory=factory) - self.assertIsInstance(con, factory) + with memory_database(factory=OkFactory) as con: + self.assertIsInstance(con, OkFactory) + regex = "Base Connection.__init__ not called." + with self.assertRaisesRegex(sqlite.ProgrammingError, regex): + with memory_database(factory=DefectFactory) as con: + self.assertIsInstance(con, DefectFactory) def test_connection_factory_relayed_call(self): # gh-95132: keyword args must not be passed as positional args @@ -57,9 +62,9 @@ def __init__(self, *args, **kwargs): kwargs["isolation_level"] = None super(Factory, self).__init__(*args, **kwargs) - con = sqlite.connect(":memory:", factory=Factory) - self.assertIsNone(con.isolation_level) - self.assertIsInstance(con, Factory) + with memory_database(factory=Factory) as con: + self.assertIsNone(con.isolation_level) + self.assertIsInstance(con, Factory) def test_connection_factory_as_positional_arg(self): class Factory(sqlite.Connection): @@ -74,18 +79,13 @@ def __init__(self, *args, **kwargs): r"parameters in Python 3.15." ) with self.assertWarnsRegex(DeprecationWarning, regex) as cm: - con = sqlite.connect(":memory:", 5.0, 0, None, True, Factory) + with memory_database(5.0, 0, None, True, Factory) as con: + self.assertIsNone(con.isolation_level) + self.assertIsInstance(con, Factory) self.assertEqual(cm.filename, __file__) - self.assertIsNone(con.isolation_level) - self.assertIsInstance(con, Factory) -class CursorFactoryTests(unittest.TestCase): - def setUp(self): - self.con = sqlite.connect(":memory:") - - def tearDown(self): - self.con.close() +class CursorFactoryTests(MemoryDatabaseMixin, unittest.TestCase): def test_is_instance(self): cur = self.con.cursor() @@ -103,9 +103,8 @@ def test_invalid_factory(self): # invalid callable returning non-cursor self.assertRaises(TypeError, self.con.cursor, lambda con: None) -class RowFactoryTestsBackwardsCompat(unittest.TestCase): - def setUp(self): - self.con = sqlite.connect(":memory:") + +class RowFactoryTestsBackwardsCompat(MemoryDatabaseMixin, unittest.TestCase): def test_is_produced_by_factory(self): cur = self.con.cursor(factory=MyCursor) @@ -114,12 +113,8 @@ def test_is_produced_by_factory(self): self.assertIsInstance(row, dict) cur.close() - def tearDown(self): - self.con.close() -class RowFactoryTests(unittest.TestCase): - def setUp(self): - self.con = sqlite.connect(":memory:") +class RowFactoryTests(MemoryDatabaseMixin, unittest.TestCase): def test_custom_factory(self): self.con.row_factory = lambda cur, row: list(row) @@ -265,12 +260,8 @@ class FakeCursor(str): self.assertRaises(TypeError, self.con.cursor, FakeCursor) self.assertRaises(TypeError, sqlite.Row, FakeCursor(), ()) - def tearDown(self): - self.con.close() -class TextFactoryTests(unittest.TestCase): - def setUp(self): - self.con = sqlite.connect(":memory:") +class TextFactoryTests(MemoryDatabaseMixin, unittest.TestCase): def test_unicode(self): austria = "Österreich" @@ -291,15 +282,17 @@ def test_custom(self): self.assertEqual(type(row[0]), str, "type of row[0] must be unicode") self.assertTrue(row[0].endswith("reich"), "column must contain original data") - def tearDown(self): - self.con.close() class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase): + def setUp(self): self.con = sqlite.connect(":memory:") self.con.execute("create table test (value text)") self.con.execute("insert into test (value) values (?)", ("a\x00b",)) + def tearDown(self): + self.con.close() + def test_string(self): # text_factory defaults to str row = self.con.execute("select value from test").fetchone() @@ -325,9 +318,6 @@ def test_custom(self): self.assertIs(type(row[0]), bytes) self.assertEqual(row[0], b"a\x00b") - def tearDown(self): - self.con.close() - if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_sqlite3/test_hooks.py b/Lib/test/test_sqlite3/test_hooks.py index 89230c08cc9143..33f0af99532a10 100644 --- a/Lib/test/test_sqlite3/test_hooks.py +++ b/Lib/test/test_sqlite3/test_hooks.py @@ -26,34 +26,31 @@ from test.support.os_helper import TESTFN, unlink -from test.test_sqlite3.test_dbapi import memory_database, cx_limit -from test.test_sqlite3.test_userfunctions import with_tracebacks +from .util import memory_database, cx_limit, with_tracebacks +from .util import MemoryDatabaseMixin -class CollationTests(unittest.TestCase): +class CollationTests(MemoryDatabaseMixin, unittest.TestCase): + def test_create_collation_not_string(self): - con = sqlite.connect(":memory:") with self.assertRaises(TypeError): - con.create_collation(None, lambda x, y: (x > y) - (x < y)) + self.con.create_collation(None, lambda x, y: (x > y) - (x < y)) def test_create_collation_not_callable(self): - con = sqlite.connect(":memory:") with self.assertRaises(TypeError) as cm: - con.create_collation("X", 42) + self.con.create_collation("X", 42) self.assertEqual(str(cm.exception), 'parameter must be callable') def test_create_collation_not_ascii(self): - con = sqlite.connect(":memory:") - con.create_collation("collä", lambda x, y: (x > y) - (x < y)) + self.con.create_collation("collä", lambda x, y: (x > y) - (x < y)) def test_create_collation_bad_upper(self): class BadUpperStr(str): def upper(self): return None - con = sqlite.connect(":memory:") mycoll = lambda x, y: -((x > y) - (x < y)) - con.create_collation(BadUpperStr("mycoll"), mycoll) - result = con.execute(""" + self.con.create_collation(BadUpperStr("mycoll"), mycoll) + result = self.con.execute(""" select x from ( select 'a' as x union @@ -68,8 +65,7 @@ def mycoll(x, y): # reverse order return -((x > y) - (x < y)) - con = sqlite.connect(":memory:") - con.create_collation("mycoll", mycoll) + self.con.create_collation("mycoll", mycoll) sql = """ select x from ( select 'a' as x @@ -79,21 +75,20 @@ def mycoll(x, y): select 'c' as x ) order by x collate mycoll """ - result = con.execute(sql).fetchall() + result = self.con.execute(sql).fetchall() self.assertEqual(result, [('c',), ('b',), ('a',)], msg='the expected order was not returned') - con.create_collation("mycoll", None) + self.con.create_collation("mycoll", None) with self.assertRaises(sqlite.OperationalError) as cm: - result = con.execute(sql).fetchall() + result = self.con.execute(sql).fetchall() self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll') def test_collation_returns_large_integer(self): def mycoll(x, y): # reverse order return -((x > y) - (x < y)) * 2**32 - con = sqlite.connect(":memory:") - con.create_collation("mycoll", mycoll) + self.con.create_collation("mycoll", mycoll) sql = """ select x from ( select 'a' as x @@ -103,7 +98,7 @@ def mycoll(x, y): select 'c' as x ) order by x collate mycoll """ - result = con.execute(sql).fetchall() + result = self.con.execute(sql).fetchall() self.assertEqual(result, [('c',), ('b',), ('a',)], msg="the expected order was not returned") @@ -112,7 +107,7 @@ def test_collation_register_twice(self): Register two different collation functions under the same name. Verify that the last one is actually used. """ - con = sqlite.connect(":memory:") + con = self.con con.create_collation("mycoll", lambda x, y: (x > y) - (x < y)) con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y))) result = con.execute(""" @@ -126,25 +121,26 @@ def test_deregister_collation(self): Register a collation, then deregister it. Make sure an error is raised if we try to use it. """ - con = sqlite.connect(":memory:") + con = self.con con.create_collation("mycoll", lambda x, y: (x > y) - (x < y)) con.create_collation("mycoll", None) with self.assertRaises(sqlite.OperationalError) as cm: con.execute("select 'a' as x union select 'b' as x order by x collate mycoll") self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll') -class ProgressTests(unittest.TestCase): + +class ProgressTests(MemoryDatabaseMixin, unittest.TestCase): + def test_progress_handler_used(self): """ Test that the progress handler is invoked once it is set. """ - con = sqlite.connect(":memory:") progress_calls = [] def progress(): progress_calls.append(None) return 0 - con.set_progress_handler(progress, 1) - con.execute(""" + self.con.set_progress_handler(progress, 1) + self.con.execute(""" create table foo(a, b) """) self.assertTrue(progress_calls) @@ -153,7 +149,7 @@ def test_opcode_count(self): """ Test that the opcode argument is respected. """ - con = sqlite.connect(":memory:") + con = self.con progress_calls = [] def progress(): progress_calls.append(None) @@ -176,11 +172,10 @@ def test_cancel_operation(self): """ Test that returning a non-zero value stops the operation in progress. """ - con = sqlite.connect(":memory:") def progress(): return 1 - con.set_progress_handler(progress, 1) - curs = con.cursor() + self.con.set_progress_handler(progress, 1) + curs = self.con.cursor() self.assertRaises( sqlite.OperationalError, curs.execute, @@ -190,7 +185,7 @@ def test_clear_handler(self): """ Test that setting the progress handler to None clears the previously set handler. """ - con = sqlite.connect(":memory:") + con = self.con action = 0 def progress(): nonlocal action @@ -203,31 +198,30 @@ def progress(): @with_tracebacks(ZeroDivisionError, name="bad_progress") def test_error_in_progress_handler(self): - con = sqlite.connect(":memory:") def bad_progress(): 1 / 0 - con.set_progress_handler(bad_progress, 1) + self.con.set_progress_handler(bad_progress, 1) with self.assertRaises(sqlite.OperationalError): - con.execute(""" + self.con.execute(""" create table foo(a, b) """) @with_tracebacks(ZeroDivisionError, name="bad_progress") def test_error_in_progress_handler_result(self): - con = sqlite.connect(":memory:") class BadBool: def __bool__(self): 1 / 0 def bad_progress(): return BadBool() - con.set_progress_handler(bad_progress, 1) + self.con.set_progress_handler(bad_progress, 1) with self.assertRaises(sqlite.OperationalError): - con.execute(""" + self.con.execute(""" create table foo(a, b) """) -class TraceCallbackTests(unittest.TestCase): +class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase): + @contextlib.contextmanager def check_stmt_trace(self, cx, expected): try: @@ -242,12 +236,11 @@ def test_trace_callback_used(self): """ Test that the trace callback is invoked once it is set. """ - con = sqlite.connect(":memory:") traced_statements = [] def trace(statement): traced_statements.append(statement) - con.set_trace_callback(trace) - con.execute("create table foo(a, b)") + self.con.set_trace_callback(trace) + self.con.execute("create table foo(a, b)") self.assertTrue(traced_statements) self.assertTrue(any("create table foo" in stmt for stmt in traced_statements)) @@ -255,7 +248,7 @@ def test_clear_trace_callback(self): """ Test that setting the trace callback to None clears the previously set callback. """ - con = sqlite.connect(":memory:") + con = self.con traced_statements = [] def trace(statement): traced_statements.append(statement) @@ -269,7 +262,7 @@ def test_unicode_content(self): Test that the statement can contain unicode literals. """ unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac' - con = sqlite.connect(":memory:") + con = self.con traced_statements = [] def trace(statement): traced_statements.append(statement) diff --git a/Lib/test/test_sqlite3/test_regression.py b/Lib/test/test_sqlite3/test_regression.py index 7e8221e7227e6e..db4e13222da9da 100644 --- a/Lib/test/test_sqlite3/test_regression.py +++ b/Lib/test/test_sqlite3/test_regression.py @@ -28,15 +28,12 @@ from test import support from unittest.mock import patch -from test.test_sqlite3.test_dbapi import memory_database, cx_limit +from .util import memory_database, cx_limit +from .util import MemoryDatabaseMixin -class RegressionTests(unittest.TestCase): - def setUp(self): - self.con = sqlite.connect(":memory:") - def tearDown(self): - self.con.close() +class RegressionTests(MemoryDatabaseMixin, unittest.TestCase): def test_pragma_user_version(self): # This used to crash pysqlite because this pragma command returns NULL for the column name @@ -45,28 +42,24 @@ def test_pragma_user_version(self): def test_pragma_schema_version(self): # This still crashed pysqlite <= 2.2.1 - con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_COLNAMES) - try: + with memory_database(detect_types=sqlite.PARSE_COLNAMES) as con: cur = self.con.cursor() cur.execute("pragma schema_version") - finally: - cur.close() - con.close() def test_statement_reset(self): # pysqlite 2.1.0 to 2.2.0 have the problem that not all statements are # reset before a rollback, but only those that are still in the # statement cache. The others are not accessible from the connection object. - con = sqlite.connect(":memory:", cached_statements=5) - cursors = [con.cursor() for x in range(5)] - cursors[0].execute("create table test(x)") - for i in range(10): - cursors[0].executemany("insert into test(x) values (?)", [(x,) for x in range(10)]) + with memory_database(cached_statements=5) as con: + cursors = [con.cursor() for x in range(5)] + cursors[0].execute("create table test(x)") + for i in range(10): + cursors[0].executemany("insert into test(x) values (?)", [(x,) for x in range(10)]) - for i in range(5): - cursors[i].execute(" " * i + "select x from test") + for i in range(5): + cursors[i].execute(" " * i + "select x from test") - con.rollback() + con.rollback() def test_column_name_with_spaces(self): cur = self.con.cursor() @@ -81,17 +74,15 @@ def test_statement_finalization_on_close_db(self): # cache when closing the database. statements that were still # referenced in cursors weren't closed and could provoke " # "OperationalError: Unable to close due to unfinalised statements". - con = sqlite.connect(":memory:") cursors = [] # default statement cache size is 100 for i in range(105): - cur = con.cursor() + cur = self.con.cursor() cursors.append(cur) cur.execute("select 1 x union select " + str(i)) - con.close() def test_on_conflict_rollback(self): - con = sqlite.connect(":memory:") + con = self.con con.execute("create table foo(x, unique(x) on conflict rollback)") con.execute("insert into foo(x) values (1)") try: @@ -126,16 +117,16 @@ def test_type_map_usage(self): a statement. This test exhibits the problem. """ SELECT = "select * from foo" - con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES) - cur = con.cursor() - cur.execute("create table foo(bar timestamp)") - with self.assertWarnsRegex(DeprecationWarning, "adapter"): - cur.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),)) - cur.execute(SELECT) - cur.execute("drop table foo") - cur.execute("create table foo(bar integer)") - cur.execute("insert into foo(bar) values (5)") - cur.execute(SELECT) + with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con: + cur = con.cursor() + cur.execute("create table foo(bar timestamp)") + with self.assertWarnsRegex(DeprecationWarning, "adapter"): + cur.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),)) + cur.execute(SELECT) + cur.execute("drop table foo") + cur.execute("create table foo(bar integer)") + cur.execute("insert into foo(bar) values (5)") + cur.execute(SELECT) def test_bind_mutating_list(self): # Issue41662: Crash when mutate a list of parameters during iteration. @@ -144,11 +135,11 @@ def __conform__(self, protocol): parameters.clear() return "..." parameters = [X(), 0] - con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES) - con.execute("create table foo(bar X, baz integer)") - # Should not crash - with self.assertRaises(IndexError): - con.execute("insert into foo(bar, baz) values (?, ?)", parameters) + with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con: + con.execute("create table foo(bar X, baz integer)") + # Should not crash + with self.assertRaises(IndexError): + con.execute("insert into foo(bar, baz) values (?, ?)", parameters) def test_error_msg_decode_error(self): # When porting the module to Python 3.0, the error message about @@ -173,7 +164,7 @@ def upper(self): def __del__(self): con.isolation_level = "" - con = sqlite.connect(":memory:") + con = self.con con.isolation_level = None for level in "", "DEFERRED", "IMMEDIATE", "EXCLUSIVE": with self.subTest(level=level): @@ -204,8 +195,7 @@ class Cursor(sqlite.Cursor): def __init__(self, con): pass - con = sqlite.connect(":memory:") - cur = Cursor(con) + cur = Cursor(self.con) with self.assertRaises(sqlite.ProgrammingError): cur.execute("select 4+5").fetchall() with self.assertRaisesRegex(sqlite.ProgrammingError, @@ -238,7 +228,9 @@ def test_auto_commit(self): 2.5.3 introduced a regression so that these could no longer be created. """ - con = sqlite.connect(":memory:", isolation_level=None) + with memory_database(isolation_level=None) as con: + self.assertIsNone(con.isolation_level) + self.assertFalse(con.in_transaction) def test_pragma_autocommit(self): """ @@ -273,9 +265,7 @@ def test_recursive_cursor_use(self): Recursively using a cursor, such as when reusing it from a generator led to segfaults. Now we catch recursive cursor usage and raise a ProgrammingError. """ - con = sqlite.connect(":memory:") - - cur = con.cursor() + cur = self.con.cursor() cur.execute("create table a (bar)") cur.execute("create table b (baz)") @@ -295,29 +285,30 @@ def test_convert_timestamp_microsecond_padding(self): since the microsecond string "456" actually represents "456000". """ - con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_DECLTYPES) - cur = con.cursor() - cur.execute("CREATE TABLE t (x TIMESTAMP)") + with memory_database(detect_types=sqlite.PARSE_DECLTYPES) as con: + cur = con.cursor() + cur.execute("CREATE TABLE t (x TIMESTAMP)") - # Microseconds should be 456000 - cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.456')") + # Microseconds should be 456000 + cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.456')") - # Microseconds should be truncated to 123456 - cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.123456789')") + # Microseconds should be truncated to 123456 + cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.123456789')") - cur.execute("SELECT * FROM t") - with self.assertWarnsRegex(DeprecationWarning, "converter"): - values = [x[0] for x in cur.fetchall()] + cur.execute("SELECT * FROM t") + with self.assertWarnsRegex(DeprecationWarning, "converter"): + values = [x[0] for x in cur.fetchall()] - self.assertEqual(values, [ - datetime.datetime(2012, 4, 4, 15, 6, 0, 456000), - datetime.datetime(2012, 4, 4, 15, 6, 0, 123456), - ]) + self.assertEqual(values, [ + datetime.datetime(2012, 4, 4, 15, 6, 0, 456000), + datetime.datetime(2012, 4, 4, 15, 6, 0, 123456), + ]) def test_invalid_isolation_level_type(self): # isolation level is a string, not an integer - self.assertRaises(TypeError, - sqlite.connect, ":memory:", isolation_level=123) + regex = "isolation_level must be str or None" + with self.assertRaisesRegex(TypeError, regex): + memory_database(isolation_level=123).__enter__() def test_null_character(self): @@ -333,7 +324,7 @@ def test_null_character(self): cur.execute, query) def test_surrogates(self): - con = sqlite.connect(":memory:") + con = self.con self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'") self.assertRaises(UnicodeEncodeError, con, "select '\udcff'") cur = con.cursor() @@ -359,7 +350,7 @@ def test_commit_cursor_reset(self): to return rows multiple times when fetched from cursors after commit. See issues 10513 and 23129 for details. """ - con = sqlite.connect(":memory:") + con = self.con con.executescript(""" create table t(c); create table t2(c); @@ -391,10 +382,9 @@ def test_bpo31770(self): """ def callback(*args): pass - con = sqlite.connect(":memory:") - cur = sqlite.Cursor(con) + cur = sqlite.Cursor(self.con) ref = weakref.ref(cur, callback) - cur.__init__(con) + cur.__init__(self.con) del cur # The interpreter shouldn't crash when ref is collected. del ref @@ -425,6 +415,7 @@ def test_return_empty_bytestring(self): def test_table_lock_cursor_replace_stmt(self): with memory_database() as con: + con = self.con cur = con.cursor() cur.execute("create table t(t)") cur.executemany("insert into t values(?)", diff --git a/Lib/test/test_sqlite3/test_transactions.py b/Lib/test/test_sqlite3/test_transactions.py index 5d211dd47b0b6b..b7b231d2225852 100644 --- a/Lib/test/test_sqlite3/test_transactions.py +++ b/Lib/test/test_sqlite3/test_transactions.py @@ -28,7 +28,8 @@ from test.support.os_helper import TESTFN, unlink from test.support.script_helper import assert_python_ok -from test.test_sqlite3.test_dbapi import memory_database +from .util import memory_database +from .util import MemoryDatabaseMixin TIMEOUT = LOOPBACK_TIMEOUT / 10 @@ -132,14 +133,14 @@ def test_locking(self): def test_rollback_cursor_consistency(self): """Check that cursors behave correctly after rollback.""" - con = sqlite.connect(":memory:") - cur = con.cursor() - cur.execute("create table test(x)") - cur.execute("insert into test(x) values (5)") - cur.execute("select 1 union select 2 union select 3") + with memory_database() as con: + cur = con.cursor() + cur.execute("create table test(x)") + cur.execute("insert into test(x) values (5)") + cur.execute("select 1 union select 2 union select 3") - con.rollback() - self.assertEqual(cur.fetchall(), [(1,), (2,), (3,)]) + con.rollback() + self.assertEqual(cur.fetchall(), [(1,), (2,), (3,)]) def test_multiple_cursors_and_iternext(self): # gh-94028: statements are cleared and reset in cursor iternext. @@ -218,10 +219,7 @@ def test_no_duplicate_rows_after_rollback_new_query(self): -class SpecialCommandTests(unittest.TestCase): - def setUp(self): - self.con = sqlite.connect(":memory:") - self.cur = self.con.cursor() +class SpecialCommandTests(MemoryDatabaseMixin, unittest.TestCase): def test_drop_table(self): self.cur.execute("create table test(i)") @@ -233,14 +231,8 @@ def test_pragma(self): self.cur.execute("insert into test(i) values (5)") self.cur.execute("pragma count_changes=1") - def tearDown(self): - self.cur.close() - self.con.close() - -class TransactionalDDL(unittest.TestCase): - def setUp(self): - self.con = sqlite.connect(":memory:") +class TransactionalDDL(MemoryDatabaseMixin, unittest.TestCase): def test_ddl_does_not_autostart_transaction(self): # For backwards compatibility reasons, DDL statements should not @@ -268,9 +260,6 @@ def test_transactional_ddl(self): with self.assertRaises(sqlite.OperationalError): self.con.execute("select * from test") - def tearDown(self): - self.con.close() - class IsolationLevelFromInit(unittest.TestCase): CREATE = "create table t(t)" diff --git a/Lib/test/test_sqlite3/test_userfunctions.py b/Lib/test/test_sqlite3/test_userfunctions.py index 05c2fb3aa6f8f2..5d12636dcd2b63 100644 --- a/Lib/test/test_sqlite3/test_userfunctions.py +++ b/Lib/test/test_sqlite3/test_userfunctions.py @@ -21,54 +21,15 @@ # misrepresented as being the original software. # 3. This notice may not be removed or altered from any source distribution. -import contextlib -import functools -import io -import re import sys import unittest import sqlite3 as sqlite from unittest.mock import Mock, patch -from test.support import bigmemtest, catch_unraisable_exception, gc_collect - -from test.test_sqlite3.test_dbapi import cx_limit - - -def with_tracebacks(exc, regex="", name=""): - """Convenience decorator for testing callback tracebacks.""" - def decorator(func): - _regex = re.compile(regex) if regex else None - @functools.wraps(func) - def wrapper(self, *args, **kwargs): - with catch_unraisable_exception() as cm: - # First, run the test with traceback enabled. - with check_tracebacks(self, cm, exc, _regex, name): - func(self, *args, **kwargs) - - # Then run the test with traceback disabled. - func(self, *args, **kwargs) - return wrapper - return decorator - - -@contextlib.contextmanager -def check_tracebacks(self, cm, exc, regex, obj_name): - """Convenience context manager for testing callback tracebacks.""" - sqlite.enable_callback_tracebacks(True) - try: - buf = io.StringIO() - with contextlib.redirect_stderr(buf): - yield - - self.assertEqual(cm.unraisable.exc_type, exc) - if regex: - msg = str(cm.unraisable.exc_value) - self.assertIsNotNone(regex.search(msg)) - if obj_name: - self.assertEqual(cm.unraisable.object.__name__, obj_name) - finally: - sqlite.enable_callback_tracebacks(False) +from test.support import bigmemtest, gc_collect + +from .util import cx_limit, memory_database +from .util import with_tracebacks, check_tracebacks def func_returntext(): @@ -405,19 +366,19 @@ def test_func_deterministic_keyword_only(self): def test_function_destructor_via_gc(self): # See bpo-44304: The destructor of the user function can # crash if is called without the GIL from the gc functions - dest = sqlite.connect(':memory:') def md5sum(t): return - dest.create_function("md5", 1, md5sum) - x = dest("create table lang (name, first_appeared)") - del md5sum, dest + with memory_database() as dest: + dest.create_function("md5", 1, md5sum) + x = dest("create table lang (name, first_appeared)") + del md5sum, dest - y = [x] - y.append(y) + y = [x] + y.append(y) - del x,y - gc_collect() + del x,y + gc_collect() @with_tracebacks(OverflowError) def test_func_return_too_large_int(self): @@ -514,6 +475,10 @@ def setUp(self): """ self.con.create_window_function("sumint", 1, WindowSumInt) + def tearDown(self): + self.cur.close() + self.con.close() + def test_win_sum_int(self): self.cur.execute(self.query % "sumint") self.assertEqual(self.cur.fetchall(), self.expected) @@ -634,6 +599,7 @@ def setUp(self): """) cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)", ("foo", 5, 3.14, None, memoryview(b"blob"),)) + cur.close() self.con.create_aggregate("nostep", 1, AggrNoStep) self.con.create_aggregate("nofinalize", 1, AggrNoFinalize) @@ -646,9 +612,7 @@ def setUp(self): self.con.create_aggregate("aggtxt", 1, AggrText) def tearDown(self): - #self.cur.close() - #self.con.close() - pass + self.con.close() def test_aggr_error_on_create(self): with self.assertRaises(sqlite.OperationalError): @@ -775,7 +739,7 @@ def setUp(self): self.con.set_authorizer(self.authorizer_cb) def tearDown(self): - pass + self.con.close() def test_table_access(self): with self.assertRaises(sqlite.DatabaseError) as cm: diff --git a/Lib/test/test_sqlite3/util.py b/Lib/test/test_sqlite3/util.py new file mode 100644 index 00000000000000..505406c437b632 --- /dev/null +++ b/Lib/test/test_sqlite3/util.py @@ -0,0 +1,78 @@ +import contextlib +import functools +import io +import re +import sqlite3 +import test.support +import unittest + + +# Helper for temporary memory databases +def memory_database(*args, **kwargs): + cx = sqlite3.connect(":memory:", *args, **kwargs) + return contextlib.closing(cx) + + +# Temporarily limit a database connection parameter +@contextlib.contextmanager +def cx_limit(cx, category=sqlite3.SQLITE_LIMIT_SQL_LENGTH, limit=128): + try: + _prev = cx.setlimit(category, limit) + yield limit + finally: + cx.setlimit(category, _prev) + + +def with_tracebacks(exc, regex="", name=""): + """Convenience decorator for testing callback tracebacks.""" + def decorator(func): + _regex = re.compile(regex) if regex else None + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + with test.support.catch_unraisable_exception() as cm: + # First, run the test with traceback enabled. + with check_tracebacks(self, cm, exc, _regex, name): + func(self, *args, **kwargs) + + # Then run the test with traceback disabled. + func(self, *args, **kwargs) + return wrapper + return decorator + + +@contextlib.contextmanager +def check_tracebacks(self, cm, exc, regex, obj_name): + """Convenience context manager for testing callback tracebacks.""" + sqlite3.enable_callback_tracebacks(True) + try: + buf = io.StringIO() + with contextlib.redirect_stderr(buf): + yield + + self.assertEqual(cm.unraisable.exc_type, exc) + if regex: + msg = str(cm.unraisable.exc_value) + self.assertIsNotNone(regex.search(msg)) + if obj_name: + self.assertEqual(cm.unraisable.object.__name__, obj_name) + finally: + sqlite3.enable_callback_tracebacks(False) + + +class MemoryDatabaseMixin: + + def setUp(self): + self.con = sqlite3.connect(":memory:") + self.cur = self.con.cursor() + + def tearDown(self): + self.cur.close() + self.con.close() + + @property + def cx(self): + return self.con + + @property + def cu(self): + return self.cur