Skip to content

Commit

Permalink
Do not hide yield
Browse files Browse the repository at this point in the history
  • Loading branch information
SvenDS9 committed Mar 9, 2023
1 parent ffab0c1 commit 4687962
Showing 1 changed file with 8 additions and 20 deletions.
28 changes: 8 additions & 20 deletions torchdata/datapipes/iter/transform/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ def _apply_fn(self, data):
return tuple(data) if t_flag else data

def __iter__(self) -> Iterator[T_co]:
try:
with futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs) as executor:
executor = futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs)
futures_deque: deque = deque()
has_next = True
Expand All @@ -764,25 +764,13 @@ def __iter__(self) -> Iterator[T_co]:
has_next = False
break

# Yield must be hidden in closure so that the futures are submitted
# before the first iterator value is required.
def result_iterator():
try:
while len(futures_deque) > 0:
nonlocal has_next
if has_next:
try:
futures_deque.append(executor.submit(self._apply_fn, next(itr)))
except StopIteration:
has_next = False
yield futures_deque.popleft().result()
finally:
executor.shutdown()

return result_iterator()
except Exception:
executor.shutdown()
raise
while len(futures_deque) > 0:
if has_next:
try:
futures_deque.append(executor.submit(self._apply_fn, next(itr)))
except StopIteration:
has_next = False
yield futures_deque.popleft().result()

def __len__(self) -> int:
if isinstance(self.datapipe, Sized):
Expand Down

0 comments on commit 4687962

Please sign in to comment.