Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-117657: Fix itertools.count thread safety #119267

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions Lib/test/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,26 @@ def test_count_with_stride(self):
self.assertEqual(type(next(c)), int)
self.assertEqual(type(next(c)), float)

def test_count_threading(self, step=1):
# this test verifies multithreading consistency, which is
# mostly for testing builds without GIL, but nice to test anyway
count_to = 10_000
num_threads = 10
c = count(step=step)
def counting_thread():
for i in range(count_to):
next(c)
threads = []
for i in range(num_threads):
thread = threading.Thread(target=counting_thread)
thread.start()
threads.append(thread)
[thread.join() for thread in threads]
self.assertEqual(next(c), count_to * num_threads * step)

def test_count_with_stride_threading(self):
self.test_count_threading(5)

def test_cycle(self):
self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
self.assertEqual(list(cycle('')), [])
Expand Down
53 changes: 40 additions & 13 deletions Modules/itertoolsmodule.c
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#include "Python.h"
#include "pycore_call.h" // _PyObject_CallNoArgs()
#include "pycore_ceval.h" // _PyEval_GetBuiltin()
#include "pycore_long.h" // _PyLong_GetZero()
#include "pycore_moduleobject.h" // _PyModule_GetState()
#include "pycore_typeobject.h" // _PyType_GetModuleState()
#include "pycore_object.h" // _PyObject_GC_TRACK()
#include "pycore_tuple.h" // _PyTuple_ITEMS()
#include "pycore_call.h" // _PyObject_CallNoArgs()
#include "pycore_ceval.h" // _PyEval_GetBuiltin()
#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION()
#include "pycore_long.h" // _PyLong_GetZero()
#include "pycore_moduleobject.h" // _PyModule_GetState()
#include "pycore_typeobject.h" // _PyType_GetModuleState()
#include "pycore_object.h" // _PyObject_GC_TRACK()
#include "pycore_tuple.h" // _PyTuple_ITEMS()

#include <stddef.h> // offsetof()
#include <stddef.h> // offsetof()

/* Itertools module written and maintained
by Raymond D. Hettinger <[email protected]>
Expand Down Expand Up @@ -3254,7 +3255,7 @@ fast_mode: when cnt an integer < PY_SSIZE_T_MAX and no step is specified.

assert(cnt != PY_SSIZE_T_MAX && long_cnt == NULL && long_step==PyLong(1));
Advances with: cnt += 1
When count hits Y_SSIZE_T_MAX, switch to slow_mode.
When count hits PY_SSIZE_T_MAX, switch to slow_mode.

slow_mode: when cnt == PY_SSIZE_T_MAX, step is not int(1), or cnt is a float.

Expand Down Expand Up @@ -3386,7 +3387,7 @@ count_nextlong(countobject *lz)

long_cnt = lz->long_cnt;
if (long_cnt == NULL) {
/* Switch to slow_mode */
/* Switching from fast mode */
long_cnt = PyLong_FromSsize_t(PY_SSIZE_T_MAX);
if (long_cnt == NULL)
return NULL;
Expand All @@ -3403,9 +3404,35 @@ count_nextlong(countobject *lz)
static PyObject *
count_next(countobject *lz)
{
if (lz->cnt == PY_SSIZE_T_MAX)
return count_nextlong(lz);
return PyLong_FromSsize_t(lz->cnt++);
PyObject *returned;
Py_ssize_t cnt;

cnt = FT_ATOMIC_LOAD_SSIZE_RELAXED(lz->cnt);
for (;;) {
if (cnt == PY_SSIZE_T_MAX) {
/* slow mode */
Py_BEGIN_CRITICAL_SECTION(lz);
returned = count_nextlong(lz);
Py_END_CRITICAL_SECTION();
return returned;
}
#ifdef Py_GIL_DISABLED
/* thread-safe fast version (increment by one).
* If lz->cnt changed between the pervious read and now,
* that means another thread got in our way. In this case,
* update cnt to new value of lz->cnt, and try again.
* Otherwise, (no other thread updated lz->cnt),
* atomically update lz->cnt with the incremented value and
* then return cnt (the previous value)
*/
if (_Py_atomic_compare_exchange_ssize(&lz->cnt, &cnt, cnt + 1)) {
return PyLong_FromSsize_t(cnt);
}
#else
/* fast mode when GIL is enabled */
return PyLong_FromSsize_t(lz->cnt++);
#endif
}
}

static PyObject *
Expand Down
1 change: 0 additions & 1 deletion Tools/tsan/suppressions_free_threading.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ race_top:_Py_dict_lookup_threadsafe
race_top:_imp_release_lock
race_top:_multiprocessing_SemLock_acquire_impl
race_top:builtin_compile_impl
race_top:count_next
race_top:dictiter_new
race_top:dictresize
race_top:insert_to_emptydict
Expand Down
Loading