Skip to content

Commit

Permalink
Make ray.util.multiprocessing.Pool imap, imap_unordered raise TypeErr…
Browse files Browse the repository at this point in the history
…or on non-iterable arguments, matching Python multiprocessing.Pool behavior.

Signed-off-by: Cade Daniel <[email protected]>
  • Loading branch information
cadedaniel committed Mar 7, 2023
1 parent 5fe6359 commit 1f63120
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 30 deletions.
16 changes: 3 additions & 13 deletions python/ray/tests/test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import time
import random
from collections import defaultdict
import warnings
import queue
import math

Expand Down Expand Up @@ -509,26 +508,17 @@ def f(args):
result_iter.next()


@pytest.mark.filterwarnings(
"default:Passing a non-iterable argument:ray.util.annotations.RayDeprecationWarning"
)
def test_warn_on_non_iterable_imap_or_imap_unordered(pool):
def test_imap_fail_on_non_iterable(pool):
def fn(_):
pass

non_iterable = 3

with warnings.catch_warnings(record=True) as w:
with pytest.raises(TypeError, match="object is not iterable"):
pool.imap(fn, non_iterable)
assert any(
"Passing a non-iterable argument" in str(warning.message) for warning in w
)

with warnings.catch_warnings(record=True) as w:
with pytest.raises(TypeError, match="object is not iterable"):
pool.imap_unordered(fn, non_iterable)
assert any(
"Passing a non-iterable argument" in str(warning.message) for warning in w
)


@pytest.mark.parametrize("use_iter", [True, False])
Expand Down
18 changes: 1 addition & 17 deletions python/ray/util/multiprocessing/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
import queue
import sys
import threading
import warnings
import time
from multiprocessing import TimeoutError
from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Tuple

import ray
from ray.util import log_once
from ray.util.annotations import RayDeprecationWarning

try:
from joblib._parallel_backends import SafeFunction
Expand Down Expand Up @@ -390,21 +388,7 @@ def __init__(self, pool, func, iterable, chunksize=None):
# submitted chunks. Ordering mirrors that in the in the ResultThread.
self._submitted_chunks = []
self._ready_objects = collections.deque()
try:
self._iterator = iter(iterable)
except TypeError:
warnings.warn(
"Passing a non-iterable argument to the "
"ray.util.multiprocessing.Pool imap and imap_unordered "
"methods is deprecated as of Ray 2.3 and "
" will be removed in a future release. See "
"https://github.com/ray-project/ray/issues/24237 for more "
"information.",
category=RayDeprecationWarning,
stacklevel=3,
)
iterable = [iterable]
self._iterator = iter(iterable)
self._iterator = iter(iterable)
if isinstance(iterable, collections.abc.Iterator):
# Got iterator (which has no len() function).
# Make default chunksize 1 instead of using _calculate_chunksize().
Expand Down

0 comments on commit 1f63120

Please sign in to comment.