Skip to content

Commit

Permalink
Ensure shutdown of executor
Browse files Browse the repository at this point in the history
  • Loading branch information
SvenDS9 committed Mar 6, 2023
1 parent ce317fd commit ffab0c1
Showing 1 changed file with 30 additions and 24 deletions.
54 changes: 30 additions & 24 deletions torchdata/datapipes/iter/transform/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,31 +752,37 @@ def _apply_fn(self, data):
return tuple(data) if t_flag else data

def __iter__(self) -> Iterator[T_co]:
executor = futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs)
futures_deque: deque = deque()
has_next = True
itr = iter(self.datapipe)
for _ in range(self.scheduled_tasks):
try:
futures_deque.append(executor.submit(self._apply_fn, next(itr)))
except StopIteration:
has_next = False
break

# Yield must be hidden in closure so that the futures are submitted
# before the first iterator value is required.
def result_iterator(executor):
while len(futures_deque) > 0:
nonlocal has_next
if has_next:
try:
futures_deque.append(executor.submit(self._apply_fn, next(itr)))
except StopIteration:
has_next = False
yield futures_deque.popleft().result()
try:
executor = futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs)
futures_deque: deque = deque()
has_next = True
itr = iter(self.datapipe)
for _ in range(self.scheduled_tasks):
try:
futures_deque.append(executor.submit(self._apply_fn, next(itr)))
except StopIteration:
has_next = False
break

# Yield must be hidden in closure so that the futures are submitted
# before the first iterator value is required.
def result_iterator():
try:
while len(futures_deque) > 0:
nonlocal has_next
if has_next:
try:
futures_deque.append(executor.submit(self._apply_fn, next(itr)))
except StopIteration:
has_next = False
yield futures_deque.popleft().result()
finally:
executor.shutdown()

return result_iterator()
except Exception:
executor.shutdown()

return result_iterator(executor)
raise

def __len__(self) -> int:
if isinstance(self.datapipe, Sized):
Expand Down

0 comments on commit ffab0c1

Please sign in to comment.