[SPARK-23754][Python] Re-raising StopIteration in client code #21383

Closed · wants to merge 14 commits
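
For context, here is a minimal standalone sketch (not part of the diff) of the failure mode tracked by SPARK-23754. PySpark evaluates each partition through chained lazy iterators, so a StopIteration that escapes user code is indistinguishable from normal iterator exhaustion and the remaining records are silently dropped; the Python 3 builtins below stand in for that pipeline.

# Hypothetical user function with a bug: it lets StopIteration escape,
# e.g. from calling next() on an exhausted iterator.
def buggy_predicate(x):
    if x == 3:
        raise StopIteration()
    return True

# The lazy filter iterator treats the escaped StopIteration as "no more data",
# so the consumer sees a silently truncated result instead of an error.
print(list(filter(buggy_predicate, range(10))))  # [0, 1, 2]
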
18 changes: 15 additions & 3 deletions python/pyspark/rdd.py
@@ -51,6 +51,7 @@
from pyspark.shuffle import Aggregator, ExternalMerger, \
get_used_memory, ExternalSorter, ExternalGroupBy
from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.util import fail_on_stopiteration


__all__ = ["RDD"]
@@ -332,7 +333,7 @@ def map(self, f, preservesPartitioning=False):
[('a', 1), ('b', 1), ('c', 1)]
"""
def func(_, iterator):
-return map(f, iterator)
+return map(fail_on_stopiteration(f), iterator)
return self.mapPartitionsWithIndex(func, preservesPartitioning)

def flatMap(self, f, preservesPartitioning=False):
@@ -347,7 +348,7 @@ def flatMap(self, f, preservesPartitioning=False):
[(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
"""
def func(s, iterator):
-return chain.from_iterable(map(f, iterator))
+return chain.from_iterable(map(fail_on_stopiteration(f), iterator))
return self.mapPartitionsWithIndex(func, preservesPartitioning)

def mapPartitions(self, f, preservesPartitioning=False):
@@ -410,7 +411,7 @@ def filter(self, f):
[2, 4]
"""
def func(iterator):
-return filter(f, iterator)
+return filter(fail_on_stopiteration(f), iterator)
return self.mapPartitions(func, True)

def distinct(self, numPartitions=None):
@@ -791,6 +792,8 @@ def foreach(self, f):
>>> def f(x): print(x)
>>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
"""
+f = fail_on_stopiteration(f)
+
def processPartition(iterator):
for x in iterator:
f(x)
@@ -840,6 +843,8 @@ def reduce(self, f):
...
ValueError: Can not reduce() empty RDD
"""
+f = fail_on_stopiteration(f)
+
def func(iterator):
iterator = iter(iterator)
try:
@@ -911,6 +916,8 @@ def fold(self, zeroValue, op):
>>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
15
"""
+op = fail_on_stopiteration(op)
+
def func(iterator):
acc = zeroValue
for obj in iterator:
@@ -943,6 +950,9 @@ def aggregate(self, zeroValue, seqOp, combOp):
>>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp)
(0, 0)
"""
+seqOp = fail_on_stopiteration(seqOp)
+combOp = fail_on_stopiteration(combOp)
+
def func(iterator):
acc = zeroValue
for obj in iterator:
@@ -1636,6 +1646,8 @@ def reduceByKeyLocally(self, func):
>>> sorted(rdd.reduceByKeyLocally(add).items())
[('a', 2), ('b', 1)]
"""
+func = fail_on_stopiteration(func)
+
def reducePartition(iterator):
m = {}
for k, v in iterator:
7 changes: 4 additions & 3 deletions python/pyspark/shuffle.py
@@ -28,6 +28,7 @@
import pyspark.heapq3 as heapq
from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
CompressedSerializer, AutoBatchedSerializer
+from pyspark.util import fail_on_stopiteration


try:
@@ -94,9 +95,9 @@ class Aggregator(object):
"""

def __init__(self, createCombiner, mergeValue, mergeCombiners):
-self.createCombiner = createCombiner
-self.mergeValue = mergeValue
-self.mergeCombiners = mergeCombiners
+self.createCombiner = fail_on_stopiteration(createCombiner)
+self.mergeValue = fail_on_stopiteration(mergeValue)
+self.mergeCombiners = fail_on_stopiteration(mergeCombiners)


class SimpleAggregator(Aggregator):
16 changes: 16 additions & 0 deletions python/pyspark/sql/tests.py
@@ -900,6 +900,22 @@ def __call__(self, x):
self.assertEqual(f, f_.func)
self.assertEqual(return_type, f_.returnType)

+def test_stopiteration_in_udf(self):

Review comment (Contributor): Can we also add tests for pandas_udf?

Review comment (Contributor): We can also merge this PR first. I can follow up with the pandas udf tests.

Review comment (Author): Yes please, I am not really familiar with UDFs in general.

+# test for SPARK-23754
+from pyspark.sql.functions import udf
+from py4j.protocol import Py4JJavaError
+
+def foo(x):
+raise StopIteration()
+
+with self.assertRaises(Py4JJavaError) as cm:

Review comment (Member): ditto for cm

+self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show()
+
+self.assertIn(
+"Caught StopIteration thrown from user's code; failing the task",
+cm.exception.java_exception.toString()
+)
+
def test_validate_column_types(self):
from pyspark.sql.functions import udf, to_json
from pyspark.sql.column import _to_java_column
14 changes: 12 additions & 2 deletions python/pyspark/sql/udf.py
@@ -25,7 +25,7 @@
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\
to_arrow_type, to_arrow_schema
-from pyspark.util import _get_argspec
+from pyspark.util import _get_argspec, fail_on_stopiteration

__all__ = ["UDFRegistration"]

@@ -157,7 +157,17 @@ def _create_judf(self):
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

-wrapped_func = _wrap_function(sc, self.func, self.returnType)
+func = fail_on_stopiteration(self.func)
+
+# for pandas UDFs the worker needs to know if the function takes
+# one or two arguments, but the signature is lost when wrapping with
+# fail_on_stopiteration, so we store it here
+if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
+func._argspec = _get_argspec(self.func)

Review comment (Contributor): Does it make sense for fail_on_stopiteration to keep the function signature itself, instead of restoring it here?

Review comment (Author): I am not sure how to do that, though; can you suggest a way? Originally this hack was in fail_on_stopiteration, but then it was decided to restrict its scope as much as possible.

Review comment (Contributor): I see. I saw @HyukjinKwon's comment here: #21383 (comment)

Review comment (@HyukjinKwon, Member, May 29, 2018): That seems to be difficult, in particular in Python 2. I'm not aware of a clean and straightforward way to copy the signature in Python 2 without a hack.

Review comment (Contributor): I see. I remember we still support pandas udf with Python 2? Does the resolution here not work with Pandas UDFs and Python 2?

Review comment (Member): Yup, we definitely support it. The current approach probably wouldn't change anything we supported before. I believe the builtin functions in Python 2 don't already work with Pandas UDFs.

Review comment (Contributor): I see. Thanks for the clarification.

(See the standalone sketch after this file's diff for an illustration of the signature-loss issue discussed above.)

+
+wrapped_func = _wrap_function(sc, func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
self._name, wrapped_func, jdt, self.evalType, self.deterministic)
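
The review thread above asks whether fail_on_stopiteration could preserve the user function's signature itself. A standalone sketch (not from the PR; the two-argument pandas UDF body below is hypothetical) of why the signature has to be stashed explicitly: a *args/**kwargs wrapper hides the original argspec that the worker needs in order to decide how to call a pandas UDF.

import inspect

def fail_on_stopiteration(f):
    # same shape as the wrapper added in python/pyspark/util.py
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except StopIteration as exc:
            raise RuntimeError(
                "Caught StopIteration thrown from user's code; failing the task", exc)
    return wrapper

def grouped_map_udf(key, pdf):  # hypothetical two-argument pandas UDF body
    return pdf

wrapped = fail_on_stopiteration(grouped_map_udf)
print(inspect.getfullargspec(grouped_map_udf).args)  # ['key', 'pdf']
print(inspect.getfullargspec(wrapped).args)          # [] -- the signature is gone

Hence the PR records the original argspec on the wrapper (func._argspec = _get_argspec(self.func)) before serialization, keeping the workaround local to the UDF path instead of baking Python 2/3 signature-copying hacks into fail_on_stopiteration, which matches the scope restriction discussed above.
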
53 changes: 53 additions & 0 deletions python/pyspark/tests.py
@@ -161,6 +161,37 @@ def gen_gs(N, step=1):
self.assertEqual(k, len(vs))
self.assertEqual(list(range(k)), list(vs))

+def test_stopiteration_is_raised(self):
+
+def stopit(*args, **kwargs):
+raise StopIteration()
+
+def legit_create_combiner(x):
+return [x]
+
+def legit_merge_value(x, y):
+return x.append(y) or x
+
+def legit_merge_combiners(x, y):
+return x.extend(y) or x
+
+data = [(x % 2, x) for x in range(100)]
+
+# wrong create combiner
+m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20)
+with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:

Review comment (Member): cm looks unused.

Review comment (Member): Let's pick up one explicit exception here too.

Review comment (@viirya, Member, Jun 1, 2018): Without the change in this test, you can still get a StopIteration in this test, can't you?

+m.mergeValues(data)
+
+# wrong merge value
+m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20)
+with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+m.mergeValues(data)
+
+# wrong merge combiners
+m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20)
+with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
+m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data))
+

class SorterTests(unittest.TestCase):
def test_in_memory_sort(self):
@@ -1246,6 +1277,28 @@ def test_pipe_unicode(self):
result = rdd.pipe('cat').collect()
self.assertEqual(data, result)

+def test_stopiteration_in_client_code(self):
+
+def stopit(*x):
+raise StopIteration()
+
+seq_rdd = self.sc.parallelize(range(10))
+keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
+
+self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect)
+self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect)
+self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
+self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit)
+self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit)
+self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit)
+self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit)
+
+# the exception raised is non-deterministic
+self.assertRaises((Py4JJavaError, RuntimeError),
+seq_rdd.aggregate, 0, stopit, lambda *x: 1)
+self.assertRaises((Py4JJavaError, RuntimeError),
+seq_rdd.aggregate, 0, lambda *x: 1, stopit)
+

class ProfilerTests(PySparkTestCase):

28 changes: 25 additions & 3 deletions python/pyspark/util.py
@@ -53,11 +53,16 @@ def _get_argspec(f):
"""
Get argspec of a function. Supports both Python 2 and Python 3.
"""
-# `getargspec` is deprecated since python3.0 (incompatible with function annotations).
-# See SPARK-23569.
-if sys.version_info[0] < 3:
+
+if hasattr(f, '_argspec'):
+# only used for pandas UDF: they wrap the user function, losing its signature
+# workers need this signature, so UDF saves it here
+argspec = f._argspec
+elif sys.version_info[0] < 3:
argspec = inspect.getargspec(f)
else:
+# `getargspec` is deprecated since python3.0 (incompatible with function annotations).
+# See SPARK-23569.
argspec = inspect.getfullargspec(f)
return argspec

@@ -89,6 +94,23 @@ def majorMinorVersion(sparkVersion):
" version numbers.")


+def fail_on_stopiteration(f):
+"""
+Wraps the input function so that a 'StopIteration' raised inside it is re-raised
+as a 'RuntimeError'; this prevents silent loss of data when 'f' is used in a for loop.
+"""
+def wrapper(*args, **kwargs):
+try:
+return f(*args, **kwargs)
+except StopIteration as exc:
+raise RuntimeError(
+"Caught StopIteration thrown from user's code; failing the task",
+exc
+)
+
+return wrapper
+
+
if __name__ == "__main__":
import doctest
(failure_count, test_count) = doctest.testmod()
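
Finally, a standalone usage sketch (not part of the diff) of the new wrapper, assuming a PySpark build with this patch applied: the escaped StopIteration from the buggy function is no longer mistaken for iterator exhaustion, so the pipeline fails loudly instead of losing data.

from pyspark.util import fail_on_stopiteration  # helper added by this PR

def buggy(x):
    if x == 3:
        raise StopIteration()  # user bug
    return x

# Wrapped, the StopIteration surfaces as a RuntimeError and the Spark task fails
# loudly rather than silently truncating the partition's output.
list(map(fail_on_stopiteration(buggy), range(6)))
# RuntimeError: Caught StopIteration thrown from user's code; failing the task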