Skip to content

Commit

Permalink
[Core] Make multiprocessing pool raise TypeError when provided a non-…
Browse files Browse the repository at this point in the history
…iterable in imap or imap_unordered ray-project#31799

This PR modifies ray.util.multiprocessing.Pool so that pool.imap and pool.imap_unordered throw TypeError when given non-iterable inputs.
Signed-off-by: Cade Daniel <[email protected]>
Signed-off-by: elliottower <[email protected]>
  • Loading branch information
cadedaniel authored and elliottower committed Apr 22, 2023
1 parent cd88663 commit d416770
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 30 deletions.
16 changes: 3 additions & 13 deletions python/ray/tests/test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import time
import random
from collections import defaultdict
import warnings
import queue
import math

Expand Down Expand Up @@ -509,26 +508,17 @@ def f(args):
result_iter.next()


@pytest.mark.filterwarnings(
"default:Passing a non-iterable argument:ray.util.annotations.RayDeprecationWarning"
)
def test_warn_on_non_iterable_imap_or_imap_unordered(pool):
def test_imap_fail_on_non_iterable(pool):
def fn(_):
pass

non_iterable = 3

with warnings.catch_warnings(record=True) as w:
with pytest.raises(TypeError, match="object is not iterable"):
pool.imap(fn, non_iterable)
assert any(
"Passing a non-iterable argument" in str(warning.message) for warning in w
)

with warnings.catch_warnings(record=True) as w:
with pytest.raises(TypeError, match="object is not iterable"):
pool.imap_unordered(fn, non_iterable)
assert any(
"Passing a non-iterable argument" in str(warning.message) for warning in w
)


@pytest.mark.parametrize("use_iter", [True, False])
Expand Down
18 changes: 1 addition & 17 deletions python/ray/util/multiprocessing/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
import queue
import sys
import threading
import warnings
import time
from multiprocessing import TimeoutError
from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Tuple

import ray
from ray.util import log_once
from ray.util.annotations import RayDeprecationWarning

try:
from joblib._parallel_backends import SafeFunction
Expand Down Expand Up @@ -390,21 +388,7 @@ def __init__(self, pool, func, iterable, chunksize=None):
# submitted chunks. Ordering mirrors that in the in the ResultThread.
self._submitted_chunks = []
self._ready_objects = collections.deque()
try:
self._iterator = iter(iterable)
except TypeError:
warnings.warn(
"Passing a non-iterable argument to the "
"ray.util.multiprocessing.Pool imap and imap_unordered "
"methods is deprecated as of Ray 2.3 and "
" will be removed in a future release. See "
"https://github.com/ray-project/ray/issues/24237 for more "
"information.",
category=RayDeprecationWarning,
stacklevel=3,
)
iterable = [iterable]
self._iterator = iter(iterable)
self._iterator = iter(iterable)
if isinstance(iterable, collections.abc.Iterator):
# Got iterator (which has no len() function).
# Make default chunksize 1 instead of using _calculate_chunksize().
Expand Down

0 comments on commit d416770

Please sign in to comment.