Skip to content

Commit

Permalink
bpo-29842: Introduce a prefetch parameter to Executor.map to handle large iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
Jason-Y-Z committed Feb 4, 2024
1 parent 28bb296 commit 46ea84e
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 18 deletions.
9 changes: 8 additions & 1 deletion Doc/library/concurrent.futures.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Executor Objects
future = executor.submit(pow, 323, 1235)
print(future.result())

.. method:: map(fn, *iterables, timeout=None, chunksize=1)
.. method:: map(fn, *iterables, timeout=None, chunksize=1, prefetch=None)

Similar to :func:`map(fn, *iterables) <map>` except:

Expand All @@ -65,9 +65,16 @@ Executor Objects
performance compared to the default size of 1. With
:class:`ThreadPoolExecutor`, *chunksize* has no effect.

By default, all tasks are queued. An explicit *prefetch* count may be
provided to specify how many extra tasks, beyond the number of workers,
should be queued.

.. versionchanged:: 3.5
Added the *chunksize* argument.

.. versionchanged:: 3.13
Added the *prefetch* argument.

.. method:: shutdown(wait=True, *, cancel_futures=False)

Signal the executor that it should free any resources that it is using
Expand Down
47 changes: 39 additions & 8 deletions Lib/concurrent/futures/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import threading
import time
import types
import weakref

FIRST_COMPLETED = 'FIRST_COMPLETED'
FIRST_EXCEPTION = 'FIRST_EXCEPTION'
Expand Down Expand Up @@ -569,6 +570,15 @@ def set_exception(self, exception):
class Executor(object):
"""This is an abstract base class for concrete asynchronous executors."""

def __init__(self, max_workers=None):
"""Initializes a new Executor instance.
Args:
max_workers: The maximum number of workers that can be used to
execute the given calls.
"""
self._max_workers = max_workers

def submit(self, fn, /, *args, **kwargs):
"""Submits a callable to be executed with the given arguments.
Expand All @@ -580,7 +590,7 @@ def submit(self, fn, /, *args, **kwargs):
"""
raise NotImplementedError()

def map(self, fn, *iterables, timeout=None, chunksize=1, prefetch=None):
    """Returns an iterator equivalent to map(fn, iter).

    Args:
        fn: A callable that will take as many arguments as there are
            passed iterables.
        timeout: The maximum number of seconds to wait. If None, then there
            is no limit on the wait time.
        chunksize: The size of the chunks the iterable will be broken into
            before being passed to a child process. This argument is only
            used by ProcessPoolExecutor; it is ignored by
            ThreadPoolExecutor.
        prefetch: The number of chunks to queue beyond the number of
            workers on the executor. If None, all chunks are queued.

    Returns:
        An iterator equivalent to: map(func, *iterables) but the calls may
        be evaluated out-of-order.

    Raises:
        TimeoutError: If the entire result iterator could not be generated
            before the given timeout.
        Exception: If fn(*args) raises for any values.
    """
    if timeout is not None:
        end_time = timeout + time.monotonic()
    if prefetch is not None and prefetch < 0:
        raise ValueError("prefetch count may not be negative")

    all_args = zip(*iterables)
    if prefetch is None:
        fs = collections.deque(self.submit(fn, *args) for args in all_args)
    else:
        # Queue at most max_workers + prefetch tasks up front.  Pull
        # items with next() rather than enumerate()+break: breaking out
        # of "for idx, args in enumerate(all_args)" consumes the element
        # that exceeds the limit and silently drops it from the results.
        fs = collections.deque()
        for _ in range(self._max_workers + prefetch):
            try:
                args = next(all_args)
            except StopIteration:
                break
            fs.append(self.submit(fn, *args))

    # Yield must be hidden in closure so that the futures are submitted
    # before the first iterator value is required.
    def result_iterator(all_args, executor_ref):
        try:
            # popleft() keeps submission order, so no reverse() is needed.
            while fs:
                # Careful not to keep a reference to the popped future
                if timeout is None:
                    yield _result_or_cancel(fs.popleft())
                else:
                    yield _result_or_cancel(
                        fs.popleft(), end_time - time.monotonic()
                    )

                # Submit the next task, if any, and only if the executor
                # still exists.  Dereference the weakref exactly once:
                # calling executor_ref() twice would race with collection
                # of the executor between the two calls.
                executor = executor_ref()
                if executor is not None:
                    try:
                        args = next(all_args)
                    except StopIteration:
                        pass
                    else:
                        fs.append(executor.submit(fn, *args))
                    # Drop the strong reference before the next yield so
                    # the generator does not keep the executor alive.
                    del executor
        finally:
            for future in fs:
                future.cancel()
    return result_iterator(all_args, weakref.ref(self))

def shutdown(self, wait=True, *, cancel_futures=False):
"""Clean-up the resources associated with the Executor.
Expand Down
16 changes: 8 additions & 8 deletions Lib/concurrent/futures/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,19 +656,17 @@ def __init__(self, max_workers=None, mp_context=None,
_check_system_limits()

if max_workers is None:
self._max_workers = os.process_cpu_count() or 1
max_workers = os.process_cpu_count() or 1
if sys.platform == 'win32':
self._max_workers = min(_MAX_WINDOWS_WORKERS,
self._max_workers)
max_workers = min(_MAX_WINDOWS_WORKERS, max_workers)
else:
if max_workers <= 0:
raise ValueError("max_workers must be greater than 0")
elif (sys.platform == 'win32' and
max_workers > _MAX_WINDOWS_WORKERS):
max_workers > _MAX_WINDOWS_WORKERS):
raise ValueError(
f"max_workers must be <= {_MAX_WINDOWS_WORKERS}")

self._max_workers = max_workers
super().__init__(max_workers)

if mp_context is None:
if max_tasks_per_child is not None:
Expand Down Expand Up @@ -812,7 +810,7 @@ def submit(self, fn, /, *args, **kwargs):
return f
submit.__doc__ = _base.Executor.submit.__doc__

def map(self, fn, *iterables, timeout=None, chunksize=1, prefetch=None):
    """Returns an iterator equivalent to map(fn, iter).

    Args:
        fn: A callable that will take as many arguments as there are
            passed iterables.
        timeout: The maximum number of seconds to wait. If None, then there
            is no limit on the wait time.
        chunksize: If greater than one, the iterables will be chopped into
            chunks of size chunksize and submitted to the process pool.
            If set to one, the items in the list will be sent one at a time.
        prefetch: The number of chunks to queue beyond the number of
            workers on the executor. If None, all chunks are queued.

    Returns:
        An iterator equivalent to: map(func, *iterables) but the calls may
        be evaluated out-of-order.

    Raises:
        TimeoutError: If the entire result iterator could not be generated
            before the given timeout.
        Exception: If fn(*args) raises for any values.
    """
    if chunksize < 1:
        raise ValueError("chunksize must be >= 1.")

    # Each submitted task processes one chunk of `chunksize` input
    # tuples, so the prefetch count is measured in chunks as well.
    chunk_fn = partial(_process_chunk, fn)
    chunks = itertools.batched(zip(*iterables), chunksize)
    results = super().map(chunk_fn, chunks,
                          timeout=timeout, prefetch=prefetch)
    return _chain_from_iterable_of_lists(results)

def shutdown(self, wait=True, *, cancel_futures=False):
Expand Down
2 changes: 1 addition & 1 deletion Lib/concurrent/futures/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __init__(self, max_workers=None, thread_name_prefix='',
if initializer is not None and not callable(initializer):
raise TypeError("initializer must be a callable")

self._max_workers = max_workers
super().__init__(max_workers)
self._work_queue = queue.SimpleQueue()
self._idle_semaphore = threading.Semaphore(0)
self._threads = set()
Expand Down
13 changes: 13 additions & 0 deletions Lib/test/test_concurrent_futures/test_thread_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,19 @@ def record_finished(n):
self.executor.shutdown(wait=True)
self.assertCountEqual(finished, range(10))

def test_map_on_infinite_iterator(self):
    import itertools

    def echo(value):
        return value

    results = self.executor.map(echo, itertools.count(0), prefetch=1)
    # A single next() proves dispatch is incremental: with an infinite
    # input, eagerly submitting every task would never return.
    first = next(results)
    results.close()  # Make sure futures cancelled

    self.assertEqual(first, 0)

def test_default_workers(self):
executor = self.executor_type()
expected = min(32, (os.process_cpu_count() or 1) + 4)
Expand Down

0 comments on commit 46ea84e

Please sign in to comment.