diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d5a237a5b2855..14d9128502ab0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -53,6 +53,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync +from pyspark.util import fail_on_stopiteration __all__ = ["RDD"] @@ -339,7 +340,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return map(f, iterator) + return map(fail_on_stopiteration(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -354,7 +355,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(map(f, iterator)) + return chain.from_iterable(map(fail_on_stopiteration(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -417,7 +418,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return filter(f, iterator) + return filter(fail_on_stopiteration(f), iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -798,6 +799,8 @@ def foreach(self, f): >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ + f = fail_on_stopiteration(f) + def processPartition(iterator): for x in iterator: f(x) @@ -847,6 +850,8 @@ def reduce(self, f): ... ValueError: Can not reduce() empty RDD """ + f = fail_on_stopiteration(f) + def func(iterator): iterator = iter(iterator) try: @@ -918,6 +923,8 @@ def fold(self, zeroValue, op): >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 """ + op = fail_on_stopiteration(op) + def func(iterator): acc = zeroValue for obj in iterator: @@ -950,6 +957,9 @@ def aggregate(self, zeroValue, seqOp, combOp): >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) (0, 0) """ + seqOp = fail_on_stopiteration(seqOp) + combOp = fail_on_stopiteration(combOp) + def func(iterator): acc = zeroValue for obj in iterator: @@ -1643,6 +1653,8 @@ def reduceByKeyLocally(self, func): >>> sorted(rdd.reduceByKeyLocally(add).items()) [('a', 2), ('b', 1)] """ + func = fail_on_stopiteration(func) + def reducePartition(iterator): m = {} for k, v in iterator: diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 02c773302e9da..bd0ac0039ffe1 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -28,6 +28,7 @@ import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ CompressedSerializer, AutoBatchedSerializer +from pyspark.util import fail_on_stopiteration try: @@ -94,9 +95,9 @@ class Aggregator(object): """ def __init__(self, createCombiner, mergeValue, mergeCombiners): - self.createCombiner = createCombiner - self.mergeValue = mergeValue - self.mergeCombiners = mergeCombiners + self.createCombiner = fail_on_stopiteration(createCombiner) + self.mergeValue = fail_on_stopiteration(mergeValue) + self.mergeCombiners = fail_on_stopiteration(mergeCombiners) class SimpleAggregator(Aggregator): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c7bd8f01b907f..a2450932e303d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -900,6 +900,22 @@ def __call__(self, x): self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) + def test_stopiteration_in_udf(self): + # test for SPARK-23754 + from pyspark.sql.functions import udf + from py4j.protocol import Py4JJavaError + + def foo(x): + raise StopIteration() + + with self.assertRaises(Py4JJavaError) as cm: + self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show() + + self.assertIn( + "Caught StopIteration thrown from user's code; failing the task", + cm.exception.java_exception.toString() + ) + def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 9dbe49b831cef..c8fb49d7c2b65 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -25,7 +25,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\ to_arrow_type, to_arrow_schema -from pyspark.util import _get_argspec +from pyspark.util import _get_argspec, fail_on_stopiteration __all__ = ["UDFRegistration"] @@ -157,7 +157,17 @@ def _create_judf(self): spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext - wrapped_func = _wrap_function(sc, self.func, self.returnType) + func = fail_on_stopiteration(self.func) + + # for pandas UDFs the worker needs to know if the function takes + # one or two arguments, but the signature is lost when wrapping with + # fail_on_stopiteration, so we store it here + if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): + func._argspec = _get_argspec(self.func) + + wrapped_func = _wrap_function(sc, func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( self._name, wrapped_func, jdt, self.evalType, self.deterministic) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 498d6b57e4353..3b37cc028c1b7 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -161,6 +161,37 @@ def gen_gs(N, step=1): self.assertEqual(k, len(vs)) self.assertEqual(list(range(k)), list(vs)) + def test_stopiteration_is_raised(self): + + def stopit(*args, **kwargs): + raise StopIteration() + + def legit_create_combiner(x): + return [x] + + def legit_merge_value(x, y): + return x.append(y) or x + + def legit_merge_combiners(x, y): + return x.extend(y) or x + + data = [(x % 2, x) for x in range(100)] + + # wrong create combiner + m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge value + m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge combiners + m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): @@ -1246,6 +1277,28 @@ def test_pipe_unicode(self): result = rdd.pipe('cat').collect() self.assertEqual(data, result) + def test_stopiteration_in_client_code(self): + + def stopit(*x): + raise StopIteration() + + seq_rdd = self.sc.parallelize(range(10)) + keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + + self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) + self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit) + + # the exception raised is non-deterministic + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, lambda *x: 1, stopit) + class ProfilerTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 59cc2a6329350..e95a9b523393f 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -53,11 +53,16 @@ def _get_argspec(f): """ Get argspec of a function. Supports both Python 2 and Python 3. """ - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. - if sys.version_info[0] < 3: + + if hasattr(f, '_argspec'): + # only used for pandas UDF: they wrap the user function, losing its signature + # workers need this signature, so UDF saves it here + argspec = f._argspec + elif sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. argspec = inspect.getfullargspec(f) return argspec @@ -89,6 +94,23 @@ def majorMinorVersion(sparkVersion): " version numbers.") +def fail_on_stopiteration(f): + """ + Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' + prevents silent loss of data when 'f' is used in a for loop + """ + def wrapper(*args, **kwargs): + try: + return f(*args, **kwargs) + except StopIteration as exc: + raise RuntimeError( + "Caught StopIteration thrown from user's code; failing the task", + exc + ) + + return wrapper + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod()