From 71da36ad6593261ad5c528a4c4f8192db2182517 Mon Sep 17 00:00:00 2001 From: Laurie O Date: Tue, 20 Feb 2024 22:40:07 +1000 Subject: [PATCH] WIP: fix threading queue shutdown all-methods-many-threads test (#115258) --- Lib/test/test_queue.py | 197 ++++++++++++++++++++++++----------------- 1 file changed, 118 insertions(+), 79 deletions(-) diff --git a/Lib/test/test_queue.py b/Lib/test/test_queue.py index 92d670ca6f8f5b..0a8afd045de268 100644 --- a/Lib/test/test_queue.py +++ b/Lib/test/test_queue.py @@ -317,97 +317,136 @@ def test_shutdown_all_methods_in_one_thread(self): def test_shutdown_immediate_all_methods_in_one_thread(self): return self._shutdown_all_methods_in_one_thread(True) - def _write_msg_thread(self, q, n, results, delay, - i_when_exec_shutdown, - event_start, event_end): - event_start.wait() - for i in range(1, n+1): - try: - q.put((i, "YDLO")) - results.append(True) - except self.queue.ShutDown: - results.append(False) - # triggers shutdown of queue - if i == i_when_exec_shutdown: - event_end.set() - time.sleep(delay) - # end of all puts - q.join() + def _shutdown_all_methods_in_many_threads(self, immediate): + # Arrange + q = self.type2test() + + start_puts = threading.Event() + start_gets = threading.Event() + put = threading.Event() + shutdown = threading.Event() + + n_gets_lock = threading.Lock() + n_gets = 0 - def _read_msg_thread(self, q, nb, results, delay, event_start): - event_start.wait() - block = True - while nb: - time.sleep(delay) + calls = [] + results = [] + queue_size_after_join = [] + + def _record_call(f, *a): + calls.append((f, a)) + return f(*a) + + def _record_result(f): try: - # Get at least one message - q.get(block) - block = False - q.task_done() - results.append(True) - nb -= 1 - except self.queue.ShutDown: - results.append(False) - nb -= 1 - except self.queue.Empty: - pass - q.join() + result = f() + except Exception as e: + results.append((f, e)) + else: + results.append((f, result)) - def _shutdown_thread(self, q, event_end, immediate): - event_end.wait() - q.shutdown(immediate) - q.join() + def put_worker(): + start_puts.wait() - def _join_thread(self, q, delay, event_start): - event_start.wait() - time.sleep(delay) - q.join() + for i in range(5): + _record_call(q.put, i) - def _shutdown_all_methods_in_many_threads(self, immediate): - q = self.type2test() - ps = [] - ev_start = threading.Event() - ev_exec_shutdown = threading.Event() - res_puts = [] - res_gets = [] - delay = 1e-4 - read_process = 4 - nb_msgs = read_process * 16 - nb_msgs_r = nb_msgs // read_process - when_exec_shutdown = nb_msgs // 2 - lprocs = ( - (self._write_msg_thread, 1, (q, nb_msgs, res_puts, delay, - when_exec_shutdown, - ev_start, ev_exec_shutdown)), - (self._read_msg_thread, read_process, (q, nb_msgs_r, - res_gets, delay*2, - ev_start)), - (self._join_thread, 2, (q, delay*2, ev_start)), - (self._shutdown_thread, 1, (q, ev_exec_shutdown, immediate)), - ) - # start all threds - for func, n, args in lprocs: - for i in range(n): - ps.append(threading.Thread(target=func, args=args)) - ps[-1].start() - # set event in order to run q.shutdown() - ev_start.set() + start_gets.set() - if not immediate: - assert(len(res_gets) == len(res_puts)) - assert(res_gets.count(True) == res_puts.count(True)) - else: - assert(len(res_gets) <= len(res_puts)) - assert(res_gets.count(True) <= res_puts.count(True)) + for i in range(5, 25): + put.wait() + _record_call(q.put, i) + put.clear() + + shutdown.set() + + # Should raise ShutDown + put.wait() + _record_call(q.put, 25) + + def get_worker(): + nonlocal n_gets + + start_gets.wait() - for thread in ps[1:]: + while True: + with n_gets_lock: + if n_gets >= 25: + break + n_gets += 1 + + put.set() + _record_call(q.get, False) + + put.set() + _record_call(q.get, False) # should raise ShutDown if immediate + + def join_worker(): + start_gets.wait() + _record_call(q.join) + queue_size_after_join.append(q.qsize()) + + def shutdown_worker(): + shutdown.wait() + _record_call(q.shutdown, immediate) + + def _start_thread(f): + thread = threading.Thread(target=_record_result, args=(f,)) + thread.start() + return thread + + threads = [ + _start_thread(put_worker), + *(_start_thread(get_worker) for _ in range(4)), + *(_start_thread(join_worker) for _ in range(2)), + _start_thread(shutdown_worker), + ] + + # Act + start_puts.set() + shutdown.wait() + for thread in threads: thread.join() - @unittest.skip("test times out (gh-115258)") + # Assert + self.assertEqual(q.qsize(), 0) + + if immediate: + self.assertTrue(all(qs > 0 for qs in queue_size_after_join)) + else: + self.assertTrue(all(qs == 0 for qs in queue_size_after_join)) + + self.assertListEqual( + [a for f, a in calls if f is q.put], [(i,) for i in range(33)] + ) + self.assertListEqual( + [a for f, a in calls if f is q.get], [(False,)] * 36 + ) + self.assertListEqual([a for f, a in calls if f is q.join], [(), ()]) + self.assertListEqual( + [a for f, a in calls if f is q.shutdown], [immediate] + ) + + put_worker_result = next(r for f, r in results if f is put_worker) + self.assertIs(put_worker_result.__class__, self.queue.ShutDown) + + get_worker_results = [r for f, r in results if f is get_worker] + if immediate: + self.assertListEqual(get_worker_results, [self.queue.ShutDown] * 4) + else: + self.assertListEqual(get_worker_results, [None] * 4) + + join_worker_results = [r for f, r in results if f is join_worker] + self.assertListEqual(join_worker_results, [None, None]) + + shutdown_worker_result = next( + r for f, r in results if f is shutdown_worker + ) + self.assertIsNone(shutdown_worker_result, None) + def test_shutdown_all_methods_in_many_threads(self): return self._shutdown_all_methods_in_many_threads(False) - @unittest.skip("test times out (gh-115258)") def test_shutdown_immediate_all_methods_in_many_threads(self): return self._shutdown_all_methods_in_many_threads(True)