[SPARK-23754][PYTHON] Re-raising StopIteration in client code
## What changes were proposed in this pull request?

Make sure that a `StopIteration` raised in users' code does not silently interrupt processing by Spark, but is re-raised as an exception to the user. The users' functions are wrapped with `fail_on_stopiteration` (in `util.py`), which re-raises `StopIteration` as a `RuntimeError`.
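
For illustration, a minimal Spark-free sketch of the failure mode and of the fix (plain Python; the wrapper below mirrors the `fail_on_stopiteration` helper added to `util.py` in this diff):

```python
def user_fn(x):
    # user code that (incorrectly) signals "stop" by raising StopIteration
    if x >= 2:
        raise StopIteration()
    return x

# The StopIteration escapes the mapped function and is read as "end of data":
# the result is silently truncated instead of failing.
print(list(map(user_fn, range(10))))  # [0, 1]


def fail_on_stopiteration(f):
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except StopIteration as exc:
            raise RuntimeError(
                "Caught StopIteration thrown from user's code; failing the task", exc)
    return wrapper

# With the wrapper the error surfaces instead of truncating the stream.
list(map(fail_on_stopiteration(user_fn), range(10)))  # raises RuntimeError
```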

## How was this patch tested?

Unit tests verifying that the exceptions are indeed raised. I am not sure how to check whether a `Py4JJavaError` contains my exception, so I simply looked for the exception message in the Java exception's `toString()`. Can you propose a better way?

## License

This is my original work, licensed in the same way as Spark.

Author: e-dorigatti <[email protected]>
Author: edorigatti <[email protected]>

Closes #21383 from e-dorigatti/fix_spark_23754.
e-dorigatti authored and HyukjinKwon committed May 30, 2018
1 parent a4be981 commit 0ebb0c0
Showing 6 changed files with 125 additions and 11 deletions.
18 changes: 15 additions & 3 deletions python/pyspark/rdd.py
@@ -53,6 +53,7 @@
from pyspark.shuffle import Aggregator, ExternalMerger, \
get_used_memory, ExternalSorter, ExternalGroupBy
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.util import fail_on_stopiteration


__all__ = ["RDD"]
@@ -339,7 +340,7 @@ def map(self, f, preservesPartitioning=False):
         [('a', 1), ('b', 1), ('c', 1)]
         """
         def func(_, iterator):
-            return map(f, iterator)
+            return map(fail_on_stopiteration(f), iterator)
         return self.mapPartitionsWithIndex(func, preservesPartitioning)

def flatMap(self, f, preservesPartitioning=False):
@@ -354,7 +355,7 @@ def flatMap(self, f, preservesPartitioning=False):
         [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
         """
         def func(s, iterator):
-            return chain.from_iterable(map(f, iterator))
+            return chain.from_iterable(map(fail_on_stopiteration(f), iterator))
         return self.mapPartitionsWithIndex(func, preservesPartitioning)

def mapPartitions(self, f, preservesPartitioning=False):
@@ -417,7 +418,7 @@ def filter(self, f):
         [2, 4]
         """
         def func(iterator):
-            return filter(f, iterator)
+            return filter(fail_on_stopiteration(f), iterator)
         return self.mapPartitions(func, True)

def distinct(self, numPartitions=None):
@@ -798,6 +799,8 @@ def foreach(self, f):
>>> def f(x): print(x)
>>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
"""
f = fail_on_stopiteration(f)

def processPartition(iterator):
for x in iterator:
f(x)
@@ -847,6 +850,8 @@ def reduce(self, f):
...
ValueError: Can not reduce() empty RDD
"""
f = fail_on_stopiteration(f)

def func(iterator):
iterator = iter(iterator)
try:
@@ -918,6 +923,8 @@ def fold(self, zeroValue, op):
>>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
15
"""
op = fail_on_stopiteration(op)

def func(iterator):
acc = zeroValue
for obj in iterator:
@@ -950,6 +957,9 @@ def aggregate(self, zeroValue, seqOp, combOp):
>>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp)
(0, 0)
"""
seqOp = fail_on_stopiteration(seqOp)
combOp = fail_on_stopiteration(combOp)

def func(iterator):
acc = zeroValue
for obj in iterator:
@@ -1643,6 +1653,8 @@ def reduceByKeyLocally(self, func):
>>> sorted(rdd.reduceByKeyLocally(add).items())
[('a', 2), ('b', 1)]
"""
func = fail_on_stopiteration(func)

def reducePartition(iterator):
m = {}
for k, v in iterator:

7 changes: 4 additions & 3 deletions python/pyspark/shuffle.py
@@ -28,6 +28,7 @@
import pyspark.heapq3 as heapq
from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
CompressedSerializer, AutoBatchedSerializer
from pyspark.util import fail_on_stopiteration


try:
@@ -94,9 +95,9 @@ class Aggregator(object):
     """

     def __init__(self, createCombiner, mergeValue, mergeCombiners):
-        self.createCombiner = createCombiner
-        self.mergeValue = mergeValue
-        self.mergeCombiners = mergeCombiners
+        self.createCombiner = fail_on_stopiteration(createCombiner)
+        self.mergeValue = fail_on_stopiteration(mergeValue)
+        self.mergeCombiners = fail_on_stopiteration(mergeCombiners)


class SimpleAggregator(Aggregator):

16 changes: 16 additions & 0 deletions python/pyspark/sql/tests.py
@@ -900,6 +900,22 @@ def __call__(self, x):
self.assertEqual(f, f_.func)
self.assertEqual(return_type, f_.returnType)

def test_stopiteration_in_udf(self):
# test for SPARK-23754
from pyspark.sql.functions import udf
from py4j.protocol import Py4JJavaError

def foo(x):
raise StopIteration()

with self.assertRaises(Py4JJavaError) as cm:
self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show()

self.assertIn(
"Caught StopIteration thrown from user's code; failing the task",
cm.exception.java_exception.toString()
)

def test_validate_column_types(self):
from pyspark.sql.functions import udf, to_json
from pyspark.sql.column import _to_java_column

14 changes: 12 additions & 2 deletions python/pyspark/sql/udf.py
@@ -25,7 +25,7 @@
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\
     to_arrow_type, to_arrow_schema
-from pyspark.util import _get_argspec
+from pyspark.util import _get_argspec, fail_on_stopiteration

 __all__ = ["UDFRegistration"]

@@ -157,7 +157,17 @@ def _create_judf(self):
         spark = SparkSession.builder.getOrCreate()
         sc = spark.sparkContext

-        wrapped_func = _wrap_function(sc, self.func, self.returnType)
+        func = fail_on_stopiteration(self.func)
+
+        # for pandas UDFs the worker needs to know if the function takes
+        # one or two arguments, but the signature is lost when wrapping with
+        # fail_on_stopiteration, so we store it here
+        if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+                             PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+                             PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
+            func._argspec = _get_argspec(self.func)
+
+        wrapped_func = _wrap_function(sc, func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
             self._name, wrapped_func, jdt, self.evalType, self.deterministic)
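
A minimal illustration (plain Python 3; `grouped_map_fn` and `wrap` are hypothetical stand-ins, not part of this patch) of the signature problem that the `_argspec` stash in `_create_judf` above works around:

```python
import inspect


def grouped_map_fn(key, pdf):
    # hypothetical two-argument pandas UDF
    return pdf


def wrap(f):
    # generic wrapper with the same (*args, **kwargs) shape as fail_on_stopiteration
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper


wrapped = wrap(grouped_map_fn)

# the original two-argument signature is no longer visible on the wrapper
print(inspect.getfullargspec(wrapped).args)         # []

# so the UDF saves it explicitly, and _get_argspec (see util.py below) reads it back
wrapped._argspec = inspect.getfullargspec(grouped_map_fn)
print(wrapped._argspec.args)                        # ['key', 'pdf']
```
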
53 changes: 53 additions & 0 deletions python/pyspark/tests.py
@@ -161,6 +161,37 @@ def gen_gs(N, step=1):
self.assertEqual(k, len(vs))
self.assertEqual(list(range(k)), list(vs))

def test_stopiteration_is_raised(self):

def stopit(*args, **kwargs):
raise StopIteration()

def legit_create_combiner(x):
return [x]

def legit_merge_value(x, y):
return x.append(y) or x

def legit_merge_combiners(x, y):
return x.extend(y) or x

data = [(x % 2, x) for x in range(100)]

# wrong create combiner
m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20)
with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
m.mergeValues(data)

# wrong merge value
m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20)
with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
m.mergeValues(data)

# wrong merge combiners
m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20)
with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data))


class SorterTests(unittest.TestCase):
def test_in_memory_sort(self):
@@ -1246,6 +1277,28 @@ def test_pipe_unicode(self):
result = rdd.pipe('cat').collect()
self.assertEqual(data, result)

def test_stopiteration_in_client_code(self):

def stopit(*x):
raise StopIteration()

seq_rdd = self.sc.parallelize(range(10))
keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))

self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit)
self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit)
self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit)
self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit)

# the exception raised is non-deterministic
self.assertRaises((Py4JJavaError, RuntimeError),
seq_rdd.aggregate, 0, stopit, lambda *x: 1)
self.assertRaises((Py4JJavaError, RuntimeError),
seq_rdd.aggregate, 0, lambda *x: 1, stopit)


class ProfilerTests(PySparkTestCase):


28 changes: 25 additions & 3 deletions python/pyspark/util.py
@@ -53,11 +53,16 @@ def _get_argspec(f):
     """
     Get argspec of a function. Supports both Python 2 and Python 3.
     """
-    # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
-    # See SPARK-23569.
-    if sys.version_info[0] < 3:
+
+    if hasattr(f, '_argspec'):
+        # only used for pandas UDF: they wrap the user function, losing its signature
+        # workers need this signature, so UDF saves it here
+        argspec = f._argspec
+    elif sys.version_info[0] < 3:
         argspec = inspect.getargspec(f)
     else:
+        # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
+        # See SPARK-23569.
         argspec = inspect.getfullargspec(f)
     return argspec

@@ -89,6 +94,23 @@ def majorMinorVersion(sparkVersion):
" version numbers.")


def fail_on_stopiteration(f):
"""
Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError';
this prevents silent loss of data when 'f' is used in a for loop.
"""
def wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except StopIteration as exc:
raise RuntimeError(
"Caught StopIteration thrown from user's code; failing the task",
exc
)

return wrapper


if __name__ == "__main__":
import doctest
(failure_count, test_count) = doctest.testmod()
