From eb22e2b251002b65f3b93e67c990c21e1151b25d Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Mon, 4 Mar 2024 13:59:30 -0700 Subject: [PATCH] gh-115490: Make the interpreter.channels and interpreter.queues Modules Handle Reloading Properly (gh-115493) The problem manifested when the .py module got reloaded and the corresponding extension module didn't. The .py module registers types with the extension and the extension was not allowing that to happen more than once. The solution: let it happen more than once. --- Lib/test/libregrtest/main.py | 9 ++------- Lib/test/test__xxinterpchannels.py | 5 +++++ Lib/test/test_interpreters/test_channels.py | 19 +++++++++++++++++++ Lib/test/test_interpreters/test_queues.py | 15 +++++++++++++++ Modules/_xxinterpchannelsmodule.c | 15 ++++++++++----- Modules/_xxinterpqueuesmodule.c | 16 +++++++++------- 6 files changed, 60 insertions(+), 19 deletions(-) diff --git a/Lib/test/libregrtest/main.py b/Lib/test/libregrtest/main.py index 6fbe3d00981ddd..126daca388fd7f 100644 --- a/Lib/test/libregrtest/main.py +++ b/Lib/test/libregrtest/main.py @@ -384,15 +384,10 @@ def run_tests_sequentially(self, runtests) -> None: result = self.run_test(test_name, runtests, tracer) - # Unload the newly imported test modules (best effort - # finalization). To work around gh-115490, don't unload - # test.support.interpreters and its submodules even if they - # weren't loaded before. - keep = "test.support.interpreters" + # Unload the newly imported test modules (best effort finalization) new_modules = [module for module in sys.modules if module not in save_modules and - module.startswith(("test.", "test_")) - and not module.startswith(keep)] + module.startswith(("test.", "test_"))] for module in new_modules: sys.modules.pop(module, None) # Remove the attribute of the parent module. diff --git a/Lib/test/test__xxinterpchannels.py b/Lib/test/test__xxinterpchannels.py index cc2ed7849b0c0f..c5d29bd2dd911f 100644 --- a/Lib/test/test__xxinterpchannels.py +++ b/Lib/test/test__xxinterpchannels.py @@ -18,6 +18,11 @@ channels = import_helper.import_module('_xxinterpchannels') +# Additional tests are found in Lib/test/test_interpreters/test_channels.py. +# New tests should be added there. +# XXX The tests here should be moved there. See the note under LowLevelTests. + + ################################## # helpers diff --git a/Lib/test/test_interpreters/test_channels.py b/Lib/test/test_interpreters/test_channels.py index 07e503837bcf75..57204e2776468d 100644 --- a/Lib/test/test_interpreters/test_channels.py +++ b/Lib/test/test_interpreters/test_channels.py @@ -1,3 +1,4 @@ +import importlib import threading from textwrap import dedent import unittest @@ -11,6 +12,24 @@ from .utils import _run_output, TestBase +class LowLevelTests(TestBase): + + # The behaviors in the low-level module is important in as much + # as it is exercised by the high-level module. Therefore the + # most # important testing happens in the high-level tests. + # These low-level tests cover corner cases that are not + # encountered by the high-level module, thus they + # mostly shouldn't matter as much. + + # Additional tests are found in Lib/test/test__xxinterpchannels.py. + # XXX Those should be either moved to LowLevelTests or eliminated + # in favor of high-level tests in this file. + + def test_highlevel_reloaded(self): + # See gh-115490 (https://github.com/python/cpython/issues/115490). + importlib.reload(channels) + + class TestChannels(TestBase): def test_create(self): diff --git a/Lib/test/test_interpreters/test_queues.py b/Lib/test/test_interpreters/test_queues.py index 905ef035d43feb..0a1fdb41f73166 100644 --- a/Lib/test/test_interpreters/test_queues.py +++ b/Lib/test/test_interpreters/test_queues.py @@ -1,3 +1,4 @@ +import importlib import threading from textwrap import dedent import unittest @@ -20,6 +21,20 @@ def tearDown(self): pass +class LowLevelTests(TestBase): + + # The behaviors in the low-level module is important in as much + # as it is exercised by the high-level module. Therefore the + # most # important testing happens in the high-level tests. + # These low-level tests cover corner cases that are not + # encountered by the high-level module, thus they + # mostly shouldn't matter as much. + + def test_highlevel_reloaded(self): + # See gh-115490 (https://github.com/python/cpython/issues/115490). + importlib.reload(queues) + + class QueueTests(TestBase): def test_create(self): diff --git a/Modules/_xxinterpchannelsmodule.c b/Modules/_xxinterpchannelsmodule.c index c0d4fd08088434..0ad184a78e8c1a 100644 --- a/Modules/_xxinterpchannelsmodule.c +++ b/Modules/_xxinterpchannelsmodule.c @@ -2675,12 +2675,17 @@ set_channelend_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv) return -1; } - if (state->send_channel_type != NULL - || state->recv_channel_type != NULL) - { - PyErr_SetString(PyExc_TypeError, "already registered"); - return -1; + // Clear the old values if the .py module was reloaded. + if (state->send_channel_type != NULL) { + (void)_PyCrossInterpreterData_UnregisterClass(state->send_channel_type); + Py_CLEAR(state->send_channel_type); } + if (state->recv_channel_type != NULL) { + (void)_PyCrossInterpreterData_UnregisterClass(state->recv_channel_type); + Py_CLEAR(state->recv_channel_type); + } + + // Add and register the types. state->send_channel_type = (PyTypeObject *)Py_NewRef(send); state->recv_channel_type = (PyTypeObject *)Py_NewRef(recv); if (ensure_xid_class(send, _channelend_shared) < 0) { diff --git a/Modules/_xxinterpqueuesmodule.c b/Modules/_xxinterpqueuesmodule.c index e35d1699cfea89..1b76b6963ae0f1 100644 --- a/Modules/_xxinterpqueuesmodule.c +++ b/Modules/_xxinterpqueuesmodule.c @@ -169,6 +169,9 @@ static int clear_module_state(module_state *state) { /* external types */ + if (state->queue_type != NULL) { + (void)_PyCrossInterpreterData_UnregisterClass(state->queue_type); + } Py_CLEAR(state->queue_type); /* QueueError */ @@ -1078,15 +1081,18 @@ set_external_queue_type(PyObject *module, PyTypeObject *queue_type) { module_state *state = get_module_state(module); + // Clear the old value if the .py module was reloaded. if (state->queue_type != NULL) { - PyErr_SetString(PyExc_TypeError, "already registered"); - return -1; + (void)_PyCrossInterpreterData_UnregisterClass( + state->queue_type); + Py_CLEAR(state->queue_type); } - state->queue_type = (PyTypeObject *)Py_NewRef(queue_type); + // Add and register the new type. if (ensure_xid_class(queue_type, _queueobj_shared) < 0) { return -1; } + state->queue_type = (PyTypeObject *)Py_NewRef(queue_type); return 0; } @@ -1703,10 +1709,6 @@ module_clear(PyObject *mod) { module_state *state = get_module_state(mod); - if (state->queue_type != NULL) { - (void)_PyCrossInterpreterData_UnregisterClass(state->queue_type); - } - // Now we clear the module state. clear_module_state(state); return 0;