[SPARK-23754][BRANCH-2.3][PYTHON] Re-raising StopIteration in client code #21463

Closed · wants to merge 1 commit
18 changes: 15 additions & 3 deletions python/pyspark/rdd.py
@@ -53,6 +53,7 @@
from pyspark.shuffle import Aggregator, ExternalMerger, \
get_used_memory, ExternalSorter, ExternalGroupBy
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.util import fail_on_stopiteration


__all__ = ["RDD"]
@@ -338,7 +339,7 @@ def map(self, f, preservesPartitioning=False):
[('a', 1), ('b', 1), ('c', 1)]
"""
def func(_, iterator):
- return map(f, iterator)
+ return map(fail_on_stopiteration(f), iterator)
return self.mapPartitionsWithIndex(func, preservesPartitioning)

def flatMap(self, f, preservesPartitioning=False):
@@ -353,7 +354,7 @@ def flatMap(self, f, preservesPartitioning=False):
[(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
"""
def func(s, iterator):
- return chain.from_iterable(map(f, iterator))
+ return chain.from_iterable(map(fail_on_stopiteration(f), iterator))
return self.mapPartitionsWithIndex(func, preservesPartitioning)

def mapPartitions(self, f, preservesPartitioning=False):
@@ -416,7 +417,7 @@ def filter(self, f):
[2, 4]
"""
def func(iterator):
- return filter(f, iterator)
+ return filter(fail_on_stopiteration(f), iterator)
return self.mapPartitions(func, True)

def distinct(self, numPartitions=None):
@@ -797,6 +798,8 @@ def foreach(self, f):
>>> def f(x): print(x)
>>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
"""
f = fail_on_stopiteration(f)

def processPartition(iterator):
for x in iterator:
f(x)
@@ -846,6 +849,8 @@ def reduce(self, f):
...
ValueError: Can not reduce() empty RDD
"""
f = fail_on_stopiteration(f)

def func(iterator):
iterator = iter(iterator)
try:
@@ -917,6 +922,8 @@ def fold(self, zeroValue, op):
>>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
15
"""
op = fail_on_stopiteration(op)

def func(iterator):
acc = zeroValue
for obj in iterator:
@@ -949,6 +956,9 @@ def aggregate(self, zeroValue, seqOp, combOp):
>>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp)
(0, 0)
"""
seqOp = fail_on_stopiteration(seqOp)
combOp = fail_on_stopiteration(combOp)

def func(iterator):
acc = zeroValue
for obj in iterator:
@@ -1642,6 +1652,8 @@ def reduceByKeyLocally(self, func):
>>> sorted(rdd.reduceByKeyLocally(add).items())
[('a', 2), ('b', 1)]
"""
func = fail_on_stopiteration(func)

def reducePartition(iterator):
m = {}
for k, v in iterator:
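Note on the rdd.py changes above: when a user function passed to the built-in map() or filter() raises StopIteration, the consumer of the resulting iterator cannot tell that apart from normal exhaustion, so the rest of the partition is silently dropped. The following is a minimal, self-contained Python 3 sketch of the failure mode and of the wrapper's effect; it is illustrative only and not part of the patch.

def f(x):
    if x == 3:
        raise StopIteration()  # user bug, e.g. leaked from hand-written iterator code
    return x * 2

# On Python 3 the consumer of map() sees this StopIteration as normal exhaustion,
# so the data is silently cut short:
print(list(map(f, [1, 2, 3, 4, 5])))      # [2, 4] -- the remaining rows are lost

def fail_on_stopiteration(func):          # same idea as the helper added in util.py below
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except StopIteration as exc:
            raise RuntimeError(
                "Caught StopIteration thrown from user's code; failing the task", exc)
    return wrapper

# With the wrapper, the bug becomes a loud failure instead of missing data:
print(list(map(fail_on_stopiteration(f), [1, 2, 3, 4, 5])))   # raises RuntimeError
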
7 changes: 4 additions & 3 deletions python/pyspark/shuffle.py
@@ -27,6 +27,7 @@
import pyspark.heapq3 as heapq
from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
CompressedSerializer, AutoBatchedSerializer
from pyspark.util import fail_on_stopiteration


try:
@@ -93,9 +94,9 @@ class Aggregator(object):
"""

def __init__(self, createCombiner, mergeValue, mergeCombiners):
- self.createCombiner = createCombiner
- self.mergeValue = mergeValue
- self.mergeCombiners = mergeCombiners
+ self.createCombiner = fail_on_stopiteration(createCombiner)
+ self.mergeValue = fail_on_stopiteration(mergeValue)
+ self.mergeCombiners = fail_on_stopiteration(mergeCombiners)


class SimpleAggregator(Aggregator):
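The three Aggregator callbacks wrapped above are invoked repeatedly inside ExternalMerger's merge loops. A minimal usage sketch (assuming a pyspark build that includes this patch is on the PYTHONPATH), mirroring the data used by the new test further down:

from pyspark.shuffle import Aggregator, ExternalMerger

agg = Aggregator(
    createCombiner=lambda v: [v],           # build a combiner from the first value of a key
    mergeValue=lambda c, v: c + [v],        # fold another value into an existing combiner
    mergeCombiners=lambda c1, c2: c1 + c2,  # merge two partial combiners
)
merger = ExternalMerger(agg, 20)            # 20 MB memory limit, as in the new test
merger.mergeValues((x % 2, x) for x in range(10))
print(sorted((k, sorted(v)) for k, v in merger.items()))
# expected: [(0, [0, 2, 4, 6, 8]), (1, [1, 3, 5, 7, 9])]
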
16 changes: 16 additions & 0 deletions python/pyspark/sql/tests.py
@@ -853,6 +853,22 @@ def __call__(self, x):
self.assertEqual(f, f_.func)
self.assertEqual(return_type, f_.returnType)

def test_stopiteration_in_udf(self):
# test for SPARK-23754
from pyspark.sql.functions import udf
from py4j.protocol import Py4JJavaError

def foo(x):
raise StopIteration()

with self.assertRaises(Py4JJavaError) as cm:
self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show()

self.assertIn(
"Caught StopIteration thrown from user's code; failing the task",
cm.exception.java_exception.toString()
)

def test_validate_column_types(self):
from pyspark.sql.functions import udf, to_json
from pyspark.sql.column import _to_java_column
4 changes: 3 additions & 1 deletion python/pyspark/sql/udf.py
@@ -24,6 +24,7 @@
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string, \
to_arrow_type, to_arrow_schema
from pyspark.util import fail_on_stopiteration

__all__ = ["UDFRegistration"]

@@ -154,7 +155,8 @@ def _create_judf(self):
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

- wrapped_func = _wrap_function(sc, self.func, self.returnType)
+ func = fail_on_stopiteration(self.func)
+ wrapped_func = _wrap_function(sc, func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
self._name, wrapped_func, jdt, self.evalType, self.deterministic)
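For context, the user-visible effect of routing self.func through fail_on_stopiteration is that a StopIteration escaping a UDF now fails the query with an explicit message instead of being swallowed. A sketch of that behaviour (illustrative only; the session setup is assumed, and it mirrors test_stopiteration_in_udf in python/pyspark/sql/tests.py above):

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

spark = SparkSession.builder.getOrCreate()

def broken(x):
    raise StopIteration()   # e.g. accidentally leaked from iterator helper code

# Without this patch the error could be swallowed, yielding incomplete output;
# with it the job fails and the Py4JJavaError message contains
# "Caught StopIteration thrown from user's code; failing the task".
spark.range(0, 10).withColumn('v', udf(broken)('id')).show()
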
53 changes: 53 additions & 0 deletions python/pyspark/tests.py
@@ -161,6 +161,37 @@ def gen_gs(N, step=1):
self.assertEqual(k, len(vs))
self.assertEqual(list(range(k)), list(vs))

def test_stopiteration_is_raised(self):

def stopit(*args, **kwargs):
raise StopIteration()

def legit_create_combiner(x):
return [x]

def legit_merge_value(x, y):
return x.append(y) or x

def legit_merge_combiners(x, y):
return x.extend(y) or x

data = [(x % 2, x) for x in range(100)]

# wrong create combiner
m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20)
with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
m.mergeValues(data)

# wrong merge value
m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20)
with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
m.mergeValues(data)

# wrong merge combiners
m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20)
with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data))


class SorterTests(unittest.TestCase):
def test_in_memory_sort(self):
@@ -1239,6 +1270,28 @@ def test_pipe_functions(self):
self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
self.assertEqual([], rdd.pipe('grep 4').collect())

def test_stopiteration_in_client_code(self):

def stopit(*x):
raise StopIteration()

seq_rdd = self.sc.parallelize(range(10))
keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))

self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit)
self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit)
self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit)
self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit)

# the exception raised is non-deterministic
self.assertRaises((Py4JJavaError, RuntimeError),
seq_rdd.aggregate, 0, stopit, lambda *x: 1)
self.assertRaises((Py4JJavaError, RuntimeError),
seq_rdd.aggregate, 0, lambda *x: 1, stopit)


class ProfilerTests(PySparkTestCase):

17 changes: 17 additions & 0 deletions python/pyspark/util.py
@@ -45,6 +45,23 @@ def _exception_message(excp):
return str(excp)


def fail_on_stopiteration(f):
"""
Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError';
this prevents silent loss of data when 'f' is used in a for loop.
"""
def wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except StopIteration as exc:
raise RuntimeError(
"Caught StopIteration thrown from user's code; failing the task",
exc
)

return wrapper


if __name__ == "__main__":
import doctest
(failure_count, test_count) = doctest.testmod()
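A quick interactive check of the new helper (assumes this patched pyspark is importable):

from pyspark.util import fail_on_stopiteration

def stopit(*args, **kwargs):
    raise StopIteration()

safe = fail_on_stopiteration(stopit)
try:
    safe(42)
except RuntimeError as err:
    print(err.args[0])   # Caught StopIteration thrown from user's code; failing the task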