Skip to content

Commit

Permalink
pythongh-123089: Make weakref.WeakSet safe against concurrent mutatio…
Browse files Browse the repository at this point in the history
…ns while it is being iterated (python#123279)

* Make `weakref.WeakSet` safe against concurrent mutations while it is being iterated.

`_IterationGuard` is no longer used for `WeakSet`, it now relies on copying the underlying set which is an atomic operation while iterating so that it can be modified by other threads.
  • Loading branch information
kumaraditya303 authored and ryan-duve committed Aug 27, 2024
1 parent 731401a commit 9729ae8
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 43 deletions.
53 changes: 10 additions & 43 deletions Lib/_weakrefset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,41 +36,26 @@ def __exit__(self, e, t, b):
class WeakSet:
def __init__(self, data=None):
self.data = set()

def _remove(item, selfref=ref(self)):
self = selfref()
if self is not None:
if self._iterating:
self._pending_removals.append(item)
else:
self.data.discard(item)
self.data.discard(item)

self._remove = _remove
# A list of keys to be removed
self._pending_removals = []
self._iterating = set()
if data is not None:
self.update(data)

def _commit_removals(self):
pop = self._pending_removals.pop
discard = self.data.discard
while True:
try:
item = pop()
except IndexError:
return
discard(item)

def __iter__(self):
with _IterationGuard(self):
for itemref in self.data:
item = itemref()
if item is not None:
# Caveat: the iterator will keep a strong reference to
# `item` until it is resumed or closed.
yield item
for itemref in self.data.copy():
item = itemref()
if item is not None:
# Caveat: the iterator will keep a strong reference to
# `item` until it is resumed or closed.
yield item

def __len__(self):
return len(self.data) - len(self._pending_removals)
return len(self.data)

def __contains__(self, item):
try:
Expand All @@ -83,21 +68,15 @@ def __reduce__(self):
return self.__class__, (list(self),), self.__getstate__()

def add(self, item):
if self._pending_removals:
self._commit_removals()
self.data.add(ref(item, self._remove))

def clear(self):
if self._pending_removals:
self._commit_removals()
self.data.clear()

def copy(self):
return self.__class__(self)

def pop(self):
if self._pending_removals:
self._commit_removals()
while True:
try:
itemref = self.data.pop()
Expand All @@ -108,18 +87,12 @@ def pop(self):
return item

def remove(self, item):
if self._pending_removals:
self._commit_removals()
self.data.remove(ref(item))

def discard(self, item):
if self._pending_removals:
self._commit_removals()
self.data.discard(ref(item))

def update(self, other):
if self._pending_removals:
self._commit_removals()
for element in other:
self.add(element)

Expand All @@ -136,8 +109,6 @@ def difference(self, other):
def difference_update(self, other):
self.__isub__(other)
def __isub__(self, other):
if self._pending_removals:
self._commit_removals()
if self is other:
self.data.clear()
else:
Expand All @@ -151,8 +122,6 @@ def intersection(self, other):
def intersection_update(self, other):
self.__iand__(other)
def __iand__(self, other):
if self._pending_removals:
self._commit_removals()
self.data.intersection_update(ref(item) for item in other)
return self

Expand Down Expand Up @@ -184,8 +153,6 @@ def symmetric_difference(self, other):
def symmetric_difference_update(self, other):
self.__ixor__(other)
def __ixor__(self, other):
if self._pending_removals:
self._commit_removals()
if self is other:
self.data.clear()
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make :class:`weakref.WeakSet` safe against concurrent mutations while it is being iterated. Patch by Kumar Aditya.

0 comments on commit 9729ae8

Please sign in to comment.