Skip to content

Commit

Permalink
Fix itertools.count in free-threading mode
Browse files Browse the repository at this point in the history
Make count_next thread-safe. count_next has two modes: slow mode
(obj->cnt set to PY_SSIZE_T_MAX), which now takes the object mutex
(only if the GIL is disabled), and fast mode, which is either a simple
cnt++ if the GIL is enabled, or an atomic compare-exchange loop if the
GIL is disabled.
  • Loading branch information
Arnon Yaari committed May 20, 2024
1 parent bf17986 commit 0813cc2
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 14 deletions.
20 changes: 20 additions & 0 deletions Lib/test/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,26 @@ def test_count_with_stride(self):
self.assertEqual(type(next(c)), int)
self.assertEqual(type(next(c)), float)

def test_count_threading(self, step=1):
# this test verifies multithreading consistency, which is
# mostly for testing builds without GIL, but nice to test anyway
count_to = 10_000
num_threads = 10
c = count(step=step)
def counting_thread():
for i in range(count_to):
next(c)
threads = []
for i in range(num_threads):
thread = threading.Thread(target=counting_thread)
thread.start()
threads.append(thread)
[thread.join() for thread in threads]
self.assertEqual(next(c), count_to * num_threads * step)

# Re-run the threaded consistency check from test_count_threading with a
# step of 5 instead of the default step of 1.
def test_count_with_stride_threading(self):
self.test_count_threading(5)

def test_cycle(self):
self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
self.assertEqual(list(cycle('')), [])
Expand Down
53 changes: 40 additions & 13 deletions Modules/itertoolsmodule.c
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#include "Python.h"
#include "pycore_call.h" // _PyObject_CallNoArgs()
#include "pycore_ceval.h" // _PyEval_GetBuiltin()
#include "pycore_long.h" // _PyLong_GetZero()
#include "pycore_moduleobject.h" // _PyModule_GetState()
#include "pycore_typeobject.h" // _PyType_GetModuleState()
#include "pycore_object.h" // _PyObject_GC_TRACK()
#include "pycore_tuple.h" // _PyTuple_ITEMS()
#include "pycore_call.h" // _PyObject_CallNoArgs()
#include "pycore_ceval.h" // _PyEval_GetBuiltin()
#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION()
#include "pycore_long.h" // _PyLong_GetZero()
#include "pycore_moduleobject.h" // _PyModule_GetState()
#include "pycore_typeobject.h" // _PyType_GetModuleState()
#include "pycore_object.h" // _PyObject_GC_TRACK()
#include "pycore_tuple.h" // _PyTuple_ITEMS()

#include <stddef.h> // offsetof()
#include <stddef.h> // offsetof()

/* Itertools module written and maintained
by Raymond D. Hettinger <[email protected]>
Expand Down Expand Up @@ -3254,7 +3255,7 @@ fast_mode: when cnt an integer < PY_SSIZE_T_MAX and no step is specified.
assert(cnt != PY_SSIZE_T_MAX && long_cnt == NULL && long_step==PyLong(1));
Advances with: cnt += 1
When count hits Y_SSIZE_T_MAX, switch to slow_mode.
When count hits PY_SSIZE_T_MAX, switch to slow_mode.
slow_mode: when cnt == PY_SSIZE_T_MAX, step is not int(1), or cnt is a float.
Expand Down Expand Up @@ -3386,7 +3387,7 @@ count_nextlong(countobject *lz)

long_cnt = lz->long_cnt;
if (long_cnt == NULL) {
/* Switch to slow_mode */
/* Switching from fast mode */
long_cnt = PyLong_FromSsize_t(PY_SSIZE_T_MAX);
if (long_cnt == NULL)
return NULL;
Expand All @@ -3403,9 +3404,35 @@ count_nextlong(countobject *lz)
/* Return the next value of the count object lz as a new Python int.
 *
 * When lz->cnt equals PY_SSIZE_T_MAX the work is delegated to
 * count_nextlong() (slow mode); otherwise the current lz->cnt is returned
 * and the counter is advanced by one (fast mode).
 *
 * NOTE(review): this listing is a diff rendering whose +/- markers were
 * lost. Going by the hunk header (9 old lines, 35 new lines), the first
 * three statements of the body are the pre-change implementation that the
 * commit removes; everything after them is the replacement.
 */
static PyObject *
count_next(countobject *lz)
{
/* Pre-change body (removed by this commit): unsynchronized check and
 * post-increment of lz->cnt. */
if (lz->cnt == PY_SSIZE_T_MAX)
return count_nextlong(lz);
return PyLong_FromSsize_t(lz->cnt++);
/* Post-change body starts here. */
PyObject *returned;
Py_ssize_t cnt;

/* Snapshot the counter once; presumably an atomic relaxed load on
 * free-threaded builds and a plain read otherwise -- TODO confirm
 * against the FT_ATOMIC_* macro definitions. */
cnt = FT_ATOMIC_LOAD_SSIZE_RELAXED(lz->cnt);
for (;;) {
if (cnt == PY_SSIZE_T_MAX) {
/* slow mode: run count_nextlong() inside a critical section on lz.
 * Per the commit description the object mutex is only taken when
 * the GIL is disabled. */
Py_BEGIN_CRITICAL_SECTION(lz);
returned = count_nextlong(lz);
Py_END_CRITICAL_SECTION();
return returned;
}
#ifdef Py_GIL_DISABLED
/* thread-safe fast version (increment by one).
 * If lz->cnt changed between the previous read and now,
 * that means another thread got in our way. In this case,
 * update cnt to the new value of lz->cnt, and try again
 * (the loop re-checks for PY_SSIZE_T_MAX first).
 * Otherwise (no other thread updated lz->cnt),
 * atomically update lz->cnt with the incremented value and
 * then return cnt (the previous value).
 */
if (_Py_atomic_compare_exchange_ssize(&lz->cnt, &cnt, cnt + 1)) {
return PyLong_FromSsize_t(cnt);
}
#else
/* fast mode when GIL is enabled: returns on the first pass, so the
 * enclosing loop never iterates a second time in this configuration */
return PyLong_FromSsize_t(lz->cnt++);
#endif
}
}

static PyObject *
Expand Down
1 change: 0 additions & 1 deletion Tools/tsan/suppressions_free_threading.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ race_top:_Py_dict_lookup_threadsafe
race_top:_imp_release_lock
race_top:_multiprocessing_SemLock_acquire_impl
race_top:builtin_compile_impl
race_top:count_next
race_top:dictiter_new
race_top:dictresize
race_top:insert_to_emptydict
Expand Down

0 comments on commit 0813cc2

Please sign in to comment.