diff --git a/django/test/testcases.py b/django/test/testcases.py index b5d426f75fe7..7382f7f0f096 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -10,6 +10,7 @@ from copy import copy, deepcopy from difflib import get_close_matches from functools import wraps +from unittest import mock from unittest.suite import _DebugResult from unittest.util import safe_repr from urllib.parse import ( @@ -37,6 +38,7 @@ from django.core.servers.basehttp import ThreadedWSGIServer, WSGIRequestHandler from django.core.signals import setting_changed from django.db import DEFAULT_DB_ALIAS, connection, connections, transaction +from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper from django.forms.fields import CharField from django.http import QueryDict from django.http.request import split_domain_port, validate_host @@ -255,6 +257,13 @@ def _add_databases_failures(cls): } method = getattr(connection, name) setattr(connection, name, _DatabaseFailure(method, message)) + cls.enterClassContext( + mock.patch.object( + BaseDatabaseWrapper, + "ensure_connection", + new=cls.ensure_connection_patch_method(), + ) + ) @classmethod def _remove_databases_failures(cls): @@ -266,6 +275,28 @@ def _remove_databases_failures(cls): method = getattr(connection, name) setattr(connection, name, method.wrapped) + @classmethod + def ensure_connection_patch_method(cls): + real_ensure_connection = BaseDatabaseWrapper.ensure_connection + + def patched_ensure_connection(self, *args, **kwargs): + if ( + self.connection is None + and self.alias not in cls.databases + and self.alias != NO_DB_ALIAS + ): + # Connection has not yet been established, but the alias is not allowed. + message = cls._disallowed_database_msg % { + "test": f"{cls.__module__}.{cls.__qualname__}", + "alias": self.alias, + "operation": "threaded connections", + } + return _DatabaseFailure(self.ensure_connection, message)() + + real_ensure_connection(self, *args, **kwargs) + + return patched_ensure_connection + def __call__(self, result=None): """ Wrapper around default __call__ method to perform common Django test diff --git a/docs/releases/5.1.txt b/docs/releases/5.1.txt index f949b31ad25b..544b1f5d0855 100644 --- a/docs/releases/5.1.txt +++ b/docs/releases/5.1.txt @@ -250,6 +250,9 @@ Tests * The new :meth:`.SimpleTestCase.assertNotInHTML` assertion allows testing that an HTML fragment is not contained in the given HTML haystack. +* In order to enforce test isolation, database connections inside threads are + no longer allowed in :class:`~django.test.SimpleTestCase`. + URLs ~~~~ diff --git a/tests/test_utils/tests.py b/tests/test_utils/tests.py index ce78ffc0084b..65a782bf87b6 100644 --- a/tests/test_utils/tests.py +++ b/tests/test_utils/tests.py @@ -1,5 +1,6 @@ import os import sys +import threading import unittest import warnings from io import StringIO @@ -2093,6 +2094,29 @@ def test_disallowed_database_chunked_cursor_queries(self): with self.assertRaisesMessage(DatabaseOperationForbidden, expected_message): next(Car.objects.iterator()) + def test_disallowed_thread_database_connection(self): + expected_message = ( + "Database threaded connections to 'default' are not allowed in " + "SimpleTestCase subclasses. Either subclass TestCase or TransactionTestCase" + " to ensure proper test isolation or add 'default' to " + "test_utils.tests.DisallowedDatabaseQueriesTests.databases to " + "silence this failure." + ) + + exceptions = [] + + def thread_func(): + try: + Car.objects.first() + except DatabaseOperationForbidden as e: + exceptions.append(e) + + t = threading.Thread(target=thread_func) + t.start() + t.join() + self.assertEqual(len(exceptions), 1) + self.assertEqual(exceptions[0].args[0], expected_message) + class AllowedDatabaseQueriesTests(SimpleTestCase): databases = {"default"} @@ -2103,6 +2127,14 @@ def test_allowed_database_queries(self): def test_allowed_database_chunked_cursor_queries(self): next(Car.objects.iterator(), None) + def test_allowed_threaded_database_queries(self): + def thread_func(): + next(Car.objects.iterator(), None) + + t = threading.Thread(target=thread_func) + t.start() + t.join() + class DatabaseAliasTests(SimpleTestCase): def setUp(self):