Skip to content

Commit

Permalink
bpo-29842: Introduce a prefetch parameter to Executor.map to handle large iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
Jason-Y-Z committed Feb 3, 2024
1 parent 28bb296 commit d8cbcb5
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 11 deletions.
9 changes: 8 additions & 1 deletion Doc/library/concurrent.futures.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Executor Objects
future = executor.submit(pow, 323, 1235)
print(future.result())

.. method:: map(fn, *iterables, timeout=None, chunksize=1)
.. method:: map(fn, *iterables, timeout=None, chunksize=1, prefetch=None)

Similar to :func:`map(fn, *iterables) <map>` except:

Expand All @@ -65,9 +65,16 @@ Executor Objects
performance compared to the default size of 1. With
:class:`ThreadPoolExecutor`, *chunksize* has no effect.

By default, all tasks are submitted to the executor immediately. An
explicit *prefetch* count may be provided to limit how many extra tasks
(chunks, in the case of :class:`ProcessPoolExecutor`) are queued beyond
the number of workers; further tasks are submitted lazily as results
are consumed.

.. versionchanged:: 3.5
Added the *chunksize* argument.

.. versionchanged:: 3.13
Added the *prefetch* argument.

.. method:: shutdown(wait=True, *, cancel_futures=False)

Signal the executor that it should free any resources that it is using
Expand Down
36 changes: 28 additions & 8 deletions Lib/concurrent/futures/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
__author__ = 'Brian Quinlan ([email protected])'

import collections
import itertools
import logging
import threading
import time
Expand Down Expand Up @@ -580,7 +581,7 @@ def submit(self, fn, /, *args, **kwargs):
"""
raise NotImplementedError()

def map(self, fn, *iterables, timeout=None, chunksize=1, prefetch=None):
    """Returns an iterator equivalent to map(fn, iter).

    Args:
        fn: A callable that will take as many arguments as there are
            passed iterables.
        timeout: The maximum number of seconds to wait. If None, then there
            is no limit on the wait time.
        chunksize: The size of the chunks the iterable will be broken into
            before being passed to a child process. This argument is only
            used by ProcessPoolExecutor; it is ignored by
            ThreadPoolExecutor.
        prefetch: The number of chunks to queue beyond the number of
            workers on the executor. If None, all chunks are queued.

    Returns:
        An iterator equivalent to: map(func, *iterables) but the calls may
        be evaluated out-of-order.

    Raises:
        TimeoutError: If the entire result iterator could not be generated
            before the given timeout.
        Exception: If fn(*args) raises for any values.
    """
    if timeout is not None:
        end_time = timeout + time.monotonic()
    if prefetch is not None and prefetch < 0:
        raise ValueError("prefetch count may not be negative")

    all_args = zip(*iterables)
    if prefetch is None:
        # Historical behaviour: submit every task eagerly up front.
        fs = collections.deque(self.submit(fn, *args) for args in all_args)
        all_args = None  # fully consumed; nothing left to backfill from
    else:
        # Submit only the first (workers + prefetch) tasks.  islice is
        # essential here: the previous `enumerate(...) + break` pattern
        # consumed one extra argument tuple from the iterator before
        # breaking and then discarded it, silently dropping a task
        # whenever the input outlived the prefetch window.
        fs = collections.deque(
            self.submit(fn, *args)
            for args in itertools.islice(all_args,
                                         self._max_workers + prefetch)
        )

    # Yield must be hidden in closure so that the futures are submitted
    # before the first iterator value is required.
    def result_iterator(all_args=all_args):
        try:
            while fs:
                # Careful not to keep a reference to the popped future
                if timeout is None:
                    yield _result_or_cancel(fs.popleft())
                else:
                    yield _result_or_cancel(fs.popleft(),
                                            end_time - time.monotonic())

                # Backfill one task per result consumed, keeping the
                # number of outstanding futures constant.
                if all_args is not None:
                    try:
                        args = next(all_args)
                    except StopIteration:
                        all_args = None  # input exhausted; stop backfilling
                    else:
                        fs.append(self.submit(fn, *args))
        finally:
            # On early exit (timeout, caller abandoning the generator),
            # cancel everything still pending.
            for future in fs:
                future.cancel()
    return result_iterator()

def shutdown(self, wait=True, *, cancel_futures=False):
"""Clean-up the resources associated with the Executor.
Expand Down
6 changes: 4 additions & 2 deletions Lib/concurrent/futures/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ def submit(self, fn, /, *args, **kwargs):
return f
submit.__doc__ = _base.Executor.submit.__doc__

def map(self, fn, *iterables, timeout=None, chunksize=1):
def map(self, fn, *iterables, timeout=None, chunksize=1, prefetch=None):
"""Returns an iterator equivalent to map(fn, iter).
Args:
Expand All @@ -823,6 +823,8 @@ def map(self, fn, *iterables, timeout=None, chunksize=1):
chunksize: If greater than one, the iterables will be chopped into
chunks of size chunksize and submitted to the process pool.
If set to one, the items in the list will be sent one at a time.
prefetch: The number of chunks to queue beyond the number of
workers on the executor. If None, all chunks are queued.
Returns:
An iterator equivalent to: map(func, *iterables) but the calls may
Expand All @@ -838,7 +840,7 @@ def map(self, fn, *iterables, timeout=None, chunksize=1):

results = super().map(partial(_process_chunk, fn),
itertools.batched(zip(*iterables), chunksize),
timeout=timeout)
timeout=timeout, prefetch=prefetch)
return _chain_from_iterable_of_lists(results)

def shutdown(self, wait=True, *, cancel_futures=False):
Expand Down
13 changes: 13 additions & 0 deletions Lib/test/test_concurrent_futures/test_thread_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ def record_finished(n):
self.executor.shutdown(wait=True)
self.assertCountEqual(finished, range(10))

def test_map_on_infinite_iterator(self):
    import itertools

    def identity(x):
        return x

    results = self.executor.map(identity, itertools.count(0), prefetch=1)
    # Consuming just one element must succeed: with an unbounded input,
    # map() can only return if it dispatches work lazily instead of
    # submitting every task up front.
    first = next(results)
    results.close()  # ensure outstanding futures get cancelled

    self.assertEqual(first, 0)

def test_default_workers(self):
executor = self.executor_type()
expected = min(32, (os.process_cpu_count() or 1) + 4)
Expand Down

0 comments on commit d8cbcb5

Please sign in to comment.