From ffab0c1e0601d5270c42cdb4f05d18259e6aa4db Mon Sep 17 00:00:00 2001
From: SvenDS9
Date: Mon, 6 Mar 2023 09:41:15 +0100
Subject: [PATCH] Ensure shutdown of executor

---
 .../datapipes/iter/transform/callable.py      | 62 ++++++++++++-------
 1 file changed, 39 insertions(+), 23 deletions(-)

diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py
index b7493c2f4..efec4a060 100644
--- a/torchdata/datapipes/iter/transform/callable.py
+++ b/torchdata/datapipes/iter/transform/callable.py
@@ -752,31 +752,47 @@ def _apply_fn(self, data):
         return tuple(data) if t_flag else data
 
     def __iter__(self) -> Iterator[T_co]:
+        """
+        Schedule ``_apply_fn`` calls on a thread pool and return an iterator over their results.
+
+        Up to ``scheduled_tasks`` items are submitted eagerly; afterwards one more item is
+        submitted before each result is yielded, for as long as the source has items left.
+        Results are yielded in submission order. The pool is shut down when the returned
+        iterator is exhausted, fails, or is closed after being started, and also when
+        scheduling the initial tasks raises.
+        """
+        # Create the pool before entering ``try`` so the handler below never sees an unbound ``executor``.
         executor = futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs)
-        futures_deque: deque = deque()
-        has_next = True
-        itr = iter(self.datapipe)
-        for _ in range(self.scheduled_tasks):
-            try:
-                futures_deque.append(executor.submit(self._apply_fn, next(itr)))
-            except StopIteration:
-                has_next = False
-                break
-
-        # Yield must be hidden in closure so that the futures are submitted
-        # before the first iterator value is required.
-        def result_iterator(executor):
-            while len(futures_deque) > 0:
-                nonlocal has_next
-                if has_next:
-                    try:
-                        futures_deque.append(executor.submit(self._apply_fn, next(itr)))
-                    except StopIteration:
-                        has_next = False
-                yield futures_deque.popleft().result()
+        try:
+            futures_deque: deque = deque()
+            has_next = True
+            itr = iter(self.datapipe)
+            for _ in range(self.scheduled_tasks):
+                try:
+                    futures_deque.append(executor.submit(self._apply_fn, next(itr)))
+                except StopIteration:
+                    has_next = False
+                    break
+
+            # Yield must be hidden in closure so that the futures are submitted
+            # before the first iterator value is required.
+            def result_iterator():
+                nonlocal has_next
+                try:
+                    while len(futures_deque) > 0:
+                        if has_next:
+                            try:
+                                futures_deque.append(executor.submit(self._apply_fn, next(itr)))
+                            except StopIteration:
+                                has_next = False
+                        yield futures_deque.popleft().result()
+                finally:
+                    executor.shutdown()
+
+            return result_iterator()
+        except BaseException:
+            # BaseException so that KeyboardInterrupt/SystemExit also release the worker threads.
             executor.shutdown()
-
-        return result_iterator(executor)
+            raise
 
     def __len__(self) -> int:
         if isinstance(self.datapipe, Sized):