From d4167700583e1c361fd8bb2ab154bd13911e8dd6 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Wed, 8 Mar 2023 13:07:18 -0800 Subject: [PATCH] [Core] Make multiprocessing pool raise TypeError when provided a non-iterable in imap or imap_unordered #31799 This PR modifies ray.util.multiprocessing.Pool so that pool.imap and pool.imap_unordered throw TypeError when given non-iterable inputs. Signed-off-by: Cade Daniel Signed-off-by: elliottower --- python/ray/tests/test_multiprocessing.py | 16 +++------------- python/ray/util/multiprocessing/pool.py | 18 +----------------- 2 files changed, 4 insertions(+), 30 deletions(-) diff --git a/python/ray/tests/test_multiprocessing.py b/python/ray/tests/test_multiprocessing.py index 07051b1ef365..7a82f39933f5 100644 --- a/python/ray/tests/test_multiprocessing.py +++ b/python/ray/tests/test_multiprocessing.py @@ -6,7 +6,6 @@ import time import random from collections import defaultdict -import warnings import queue import math @@ -509,26 +508,17 @@ def f(args): result_iter.next() -@pytest.mark.filterwarnings( - "default:Passing a non-iterable argument:ray.util.annotations.RayDeprecationWarning" -) -def test_warn_on_non_iterable_imap_or_imap_unordered(pool): +def test_imap_fail_on_non_iterable(pool): def fn(_): pass non_iterable = 3 - with warnings.catch_warnings(record=True) as w: + with pytest.raises(TypeError, match="object is not iterable"): pool.imap(fn, non_iterable) - assert any( - "Passing a non-iterable argument" in str(warning.message) for warning in w - ) - with warnings.catch_warnings(record=True) as w: + with pytest.raises(TypeError, match="object is not iterable"): pool.imap_unordered(fn, non_iterable) - assert any( - "Passing a non-iterable argument" in str(warning.message) for warning in w - ) @pytest.mark.parametrize("use_iter", [True, False]) diff --git a/python/ray/util/multiprocessing/pool.py b/python/ray/util/multiprocessing/pool.py index fa659a5b4ae5..299b271ca186 100644 --- a/python/ray/util/multiprocessing/pool.py +++ b/python/ray/util/multiprocessing/pool.py @@ -7,14 +7,12 @@ import queue import sys import threading -import warnings import time from multiprocessing import TimeoutError from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Tuple import ray from ray.util import log_once -from ray.util.annotations import RayDeprecationWarning try: from joblib._parallel_backends import SafeFunction @@ -390,21 +388,7 @@ def __init__(self, pool, func, iterable, chunksize=None): # submitted chunks. Ordering mirrors that in the in the ResultThread. self._submitted_chunks = [] self._ready_objects = collections.deque() - try: - self._iterator = iter(iterable) - except TypeError: - warnings.warn( - "Passing a non-iterable argument to the " - "ray.util.multiprocessing.Pool imap and imap_unordered " - "methods is deprecated as of Ray 2.3 and " - " will be removed in a future release. See " - "https://github.com/ray-project/ray/issues/24237 for more " - "information.", - category=RayDeprecationWarning, - stacklevel=3, - ) - iterable = [iterable] - self._iterator = iter(iterable) + self._iterator = iter(iterable) if isinstance(iterable, collections.abc.Iterator): # Got iterator (which has no len() function). # Make default chunksize 1 instead of using _calculate_chunksize().