Skip to content

Commit

Permalink
Sync iterator wrapper has fewer possible states.
Browse files Browse the repository at this point in the history
We use an enum to thoroughly document the possible states the yielding
function can be in, making the code a lot more readable IMHO!
  • Loading branch information
aebrahim committed Oct 11, 2023
1 parent 9df9ecd commit ba26773
Showing 1 changed file with 52 additions and 67 deletions.
119 changes: 52 additions & 67 deletions once/_iterator_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import collections.abc
import enum
import threading

# Before we begin, a note on the assert statements in this file:
Expand Down Expand Up @@ -91,6 +92,15 @@ async def yield_results(self) -> collections.abc.AsyncGenerator:
i += 1


class _IteratorAction(enum.Enum):
# Generating the next value from the underlying iterator
GENERATING = 1
# Yield an already computed value
YIELDING = 2
# Waiting for the underlying iterator, already triggered from another call.
WAITING = 3


class GeneratorWrapper:
"""Wrapper around an sync generator which only runs once.
Expand All @@ -107,80 +117,55 @@ def __init__(self, func, *args, **kwargs) -> None:
self.next_send = None

def yield_results(self) -> collections.abc.Generator:
# Fast path for subsequent repeated call:
with self.lock:
finished = self.finished
if finished:
yield from self.results
return
i = 0
yield_value = None
next_send = None
# Fast path for subsequent calls will not require a lock
while True:
# If we on before the penultimate entry, we can return now. When yielding the last
# element of results, we need to be recording next_send, so that needs the lock.
if i < len(self.results) - 1:
yield self.results[i]
i += 1
continue
# Because we don't hold a lock here, we can't make this assumption
# i == len(self.results) - 1 or i == len(self.results)
# because the iterator could have moved in the interim. However, it will no longer
# move once self.finished.
if self.finished:
if i < len(self.results):
yield self.results[i]
i += 1
continue
if i == len(self.results):
return

# Initial calls, and concurrent calls before completion will require the lock.
action: _IteratorAction | None = None
# With a lock, we figure out which action to take, and then we take it after release.
with self.lock:
# Just in case a race condition prevented us from hitting these conditions before,
# check them again, so they can be handled by the code before the lock.
if i < len(self.results) - 1:
continue
if self.finished:
if i < len(self.results):
continue
if i == len(self.results):
if i == len(self.results):
if self.finished:
return
assert i == len(self.results) - 1 or i == len(self.results)
# If we are at the end and waiting for the generator to complete, there is nothing
# to do!
if self.generating and i == len(self.results):
continue

# At this point, there are 2 states to handle, which we will want to do outside the
# lock to avoid deadlocks.
# State #1: We are about to yield back the last entry in self.results and potentially
# log next send. We can allow multiple calls to enter this state, as long
# as we re-grab the lock before modifying self.next_send
# State #2: We are at the end of self.results, and need to call our underlying
# iterator. Only one call may enter this state due to our check of
# self.generating above.
if i == len(self.results) and not self.generating:
self.generating = True
next_send = self.next_send
listening = False
if self.generating:
action = _IteratorAction.WAITING
else:
action = _IteratorAction.GENERATING
next_send = self.next_send
else:
assert i == len(self.results) - 1 or self.generating
listening = True
# We break outside the lock to either listen or kick off a new generation.
if listening:
next_send = yield self.results[i]
action = _IteratorAction.YIELDING
yield_value = self.results[i]
if action == _IteratorAction.WAITING:
continue
if action == _IteratorAction.YIELDING:
next_send = yield yield_value
i += 1
# If this is the last element and we have not yet kicked off the next iteration,
# we need to record the next send value.
with self.lock:
if not self.finished and i == len(self.results):
if i == len(self.results) and not self.generating:
self.next_send = next_send
continue
# We must be in generating state
assert self.generator is not None
try:
result = self.generator.send(next_send)
except StopIteration:
# This lock should be unnecessary, which by definition means there should be no
# contention on it, so we use it to preserve our assumptions about variables which
# are modified under lock.
with self.lock:
self.finished = True
self.generator = None # Allow this to be GCed.
self.generating = False
return
with self.lock:
self.results.append(result)
self.generating = False
if action == _IteratorAction.GENERATING:
assert self.generator is not None
try:
result = self.generator.send(next_send)
except StopIteration:
# This lock should be unnecessary, which by definition means there should be no
# contention on it, so we use it to preserve our assumptions about variables which
# are modified under lock.
with self.lock:
self.finished = True
self.generator = None # Allow this to be GCed.
self.generating = False
else:
with self.lock:
self.generating = False
self.results.append(result)

0 comments on commit ba26773

Please sign in to comment.