[SPARK-23754][Python] Re-raising StopIteration in client code #21383

Closed
wants to merge 14 commits into from
35 changes: 24 additions & 11 deletions python/pyspark/rdd.py
@@ -49,7 +49,7 @@
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, ExternalMerger, \
get_used_memory, ExternalSorter, ExternalGroupBy
get_used_memory, ExternalSorter, ExternalGroupBy, safe_iter
from pyspark.traceback_utils import SCCallSiteSync


@@ -173,6 +173,7 @@ def ignore_unicode_prefix(f):
return f



Member:
btw, I would revert unrelated changes to make the backport easier.

Contributor Author:
I am not really familiar with the codebase; can you provide more details, please?

Contributor:
this new line is an unnecessary change

class Partitioner(object):
def __init__(self, numPartitions, partitionFunc):
self.numPartitions = numPartitions
@@ -332,7 +333,7 @@ def map(self, f, preservesPartitioning=False):
[('a', 1), ('b', 1), ('c', 1)]
"""
def func(_, iterator):
return map(f, iterator)
return map(safe_iter(f), iterator)
return self.mapPartitionsWithIndex(func, preservesPartitioning)

def flatMap(self, f, preservesPartitioning=False):
@@ -347,7 +348,7 @@ def flatMap(self, f, preservesPartitioning=False):
[(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
"""
def func(s, iterator):
return chain.from_iterable(map(f, iterator))
return chain.from_iterable(map(safe_iter(f), iterator))
return self.mapPartitionsWithIndex(func, preservesPartitioning)

def mapPartitions(self, f, preservesPartitioning=False):
@@ -410,7 +411,7 @@ def filter(self, f):
[2, 4]
"""
def func(iterator):
return filter(f, iterator)
return filter(safe_iter(f), iterator)
return self.mapPartitions(func, True)

def distinct(self, numPartitions=None):
@@ -791,9 +792,11 @@ def foreach(self, f):
>>> def f(x): print(x)
>>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
"""
safe_f = safe_iter(f)

def processPartition(iterator):
for x in iterator:
f(x)
safe_f(x)
return iter([])
self.mapPartitions(processPartition).count() # Force evaluation

@@ -840,13 +843,15 @@ def reduce(self, f):
...
ValueError: Can not reduce() empty RDD
"""
safe_f = safe_iter(f)

def func(iterator):
iterator = iter(iterator)
try:
initial = next(iterator)
except StopIteration:
return
yield reduce(f, iterator, initial)
yield reduce(safe_f, iterator, initial)

vals = self.mapPartitions(func).collect()
if vals:
@@ -911,10 +916,12 @@ def fold(self, zeroValue, op):
>>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
15
"""
safe_op = safe_iter(op)

def func(iterator):
acc = zeroValue
for obj in iterator:
acc = op(acc, obj)
acc = safe_op(acc, obj)
yield acc
# collecting result of mapPartitions here ensures that the copy of
# zeroValue provided to each partition is unique from the one provided
@@ -943,16 +950,19 @@ def aggregate(self, zeroValue, seqOp, combOp):
>>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp)
(0, 0)
"""
safe_seqOp = safe_iter(seqOp)
safe_combOp = safe_iter(combOp)

def func(iterator):
acc = zeroValue
for obj in iterator:
acc = seqOp(acc, obj)
acc = safe_seqOp(acc, obj)
yield acc
# collecting result of mapPartitions here ensures that the copy of
# zeroValue provided to each partition is unique from the one provided
# to the final reduce call
vals = self.mapPartitions(func).collect()
return reduce(combOp, vals, zeroValue)
return reduce(safe_combOp, vals, zeroValue)

def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
"""
@@ -1636,15 +1646,17 @@ def reduceByKeyLocally(self, func):
>>> sorted(rdd.reduceByKeyLocally(add).items())
[('a', 2), ('b', 1)]
"""
safe_func = safe_iter(func)

def reducePartition(iterator):
m = {}
for k, v in iterator:
m[k] = func(m[k], v) if k in m else v
m[k] = safe_func(m[k], v) if k in m else v
yield m

def mergeMaps(m1, m2):
for k, v in m2.items():
m1[k] = func(m1[k], v) if k in m1 else v
m1[k] = safe_func(m1[k], v) if k in m1 else v
return m1
return self.mapPartitions(reducePartition).reduce(mergeMaps)

@@ -1846,6 +1858,7 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
>>> sorted(x.combineByKey(to_list, append, extend).collect())
[('a', [1, 2]), ('b', [1])]
"""

Contributor:
nit: unnecessary change

Member:
Let's address this comment before merging it in for the backport.

if numPartitions is None:
numPartitions = self._defaultReducePartitions()

19 changes: 16 additions & 3 deletions python/pyspark/shuffle.py
@@ -67,6 +67,19 @@ def get_used_memory():
return 0


def safe_iter(f):
Contributor:
how about fail_on_StopIteration?

Contributor (@cloud-fan, May 22, 2018):
I'm also wondering why we put it in shuffle.py instead of util.py

""" wraps f to make it safe (= does not lead to data loss) to use inside a for loop
Contributor:
It sounds like this is a potential correctness issue, so the eventual fix for this should be backported to maintenance releases (at least the most recent ones and the next 2.3.x).

I saw the examples provided on the linked JIRAs, but do you have an example of a realistic user workload where this problem can occur (i.e. a case where the problem is more subtle than explicitly throwing StopIteration())? Would that be something like calling next() past the end of an iterator (which I suppose could occur deep in library code)?

Contributor Author:
Yes, indeed that's how I stumbled upon this bug in the first place
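
As an illustrative aside (a hypothetical example, not taken from the PR or the JIRA), this is the kind of user code where a bare next() deep in user or library code can let a StopIteration escape:

def make_enricher():
    lookup = iter([10, 20])  # shorter than the data being processed

    def enrich(x):
        # Somewhere deep in user/library code, next() eventually runs past the
        # end of `lookup` and raises StopIteration out of the user function.
        return (x, next(lookup))
    return enrich

# Before this fix, something like rdd.map(make_enricher()).collect() would not
# fail; the affected partition would simply be cut short where the
# StopIteration escaped.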

make StopIterations raised inside f explicit
"""
def wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except StopIteration as exc:
raise RuntimeError('StopIteration in client code', exc)
Contributor:
why is it only a problem for client mode?

Contributor Author:
Most of the time, the user function is called within a loop (see e.g. RDD.foreach here). If a StopIteration is raised inside that function, PySpark will exit the loop and will not process the items remaining in that particular partition. This means that some data will disappear, leaving no trace. See this issue for a code example.
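
To make the failure mode concrete, here is a minimal pure-Python sketch of the mechanism (illustrative only, not PySpark's exact internals): the user function is composed into the partition iterator, so whatever drives that iterator treats the escaping StopIteration as end-of-data.

def user_fn(x):
    if x == 2:
        raise StopIteration()  # escapes from user code
    return x * 10

# PySpark builds something like map(f, iterator) for the partition; the loop
# that consumes it interprets the StopIteration as exhaustion and stops early.
print(list(map(user_fn, range(5))))  # prints [0, 10]; 2, 3 and 4 are silently lost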

Contributor (@cloud-fan, May 22, 2018):
yea I got that, but why do we emphasise "client mode" here?

Contributor Author:
I wanted to be explicit and let the user know it's their problem. Can you suggest a better message?

Contributor:
oh sorry, I misread it as client mode; it's actually client code. LGTM then

Contributor:
What about user code? That's a bit clearer: it makes it clear that the exception originates in code that was not part of Spark itself but was instead written by a user of the Spark framework. Words like client and consumer could be confusing because both are used in other, more precise technical contexts within our codebase (e.g. client deploy mode, streaming consumer).

Contributor Author:
I see, then what about Caught StopIteration in user's code, handle it appropriately?

Contributor:
What about Caught StopIteration thrown from user's code; failing the task, to make it clear that the exception is expected to bubble up and fail the task?


return wrapper
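
As a quick illustration of the wrapper's effect (a sketch that assumes the safe_iter defined above is in scope), the same mistake now fails loudly instead of silently truncating the data:

def user_fn(x):
    raise StopIteration()

print(list(map(user_fn, range(5))))  # [] -- StopIteration silently swallowed

safe_fn = safe_iter(user_fn)
try:
    list(map(safe_fn, range(5)))     # now fails loudly
except RuntimeError as e:
    print(e)  # ('StopIteration in client code', StopIteration())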


def _get_local_dirs(sub):
""" Get all the directories """
path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp")
@@ -94,9 +107,9 @@ class Aggregator(object):
"""

def __init__(self, createCombiner, mergeValue, mergeCombiners):
self.createCombiner = createCombiner
self.mergeValue = mergeValue
self.mergeCombiners = mergeCombiners
self.createCombiner = safe_iter(createCombiner)
self.mergeValue = safe_iter(mergeValue)
self.mergeCombiners = safe_iter(mergeCombiners)


class SimpleAggregator(Aggregator):
13 changes: 13 additions & 0 deletions python/pyspark/sql/tests.py
@@ -900,6 +900,19 @@ def __call__(self, x):
self.assertEqual(f, f_.func)
self.assertEqual(return_type, f_.returnType)

def test_stopiteration_in_udf(self):
Contributor:
Can we also add tests for pandas_udf?

Contributor:
We can also merge this PR first. I can follow up with the pandas udf tests.

Contributor Author:
Yes please, I am not really familiar with UDFs in general

# test for SPARK-23754
from pyspark.sql.functions import udf
from py4j.protocol import Py4JJavaError

def foo(x):
raise StopIteration()

with self.assertRaises(Py4JJavaError) as cm:
Member:
ditto for cm

self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show()
self.assertIn('StopIteration in client code',
cm.exception.java_exception.toString())
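
On the pandas_udf coverage discussed above, a rough sketch of what the follow-up test might look like (hypothetical, not part of this PR; it assumes pandas/pyarrow are available and that the pandas_udf execution path re-raises StopIteration the same way as plain UDFs):

def test_stopiteration_in_pandas_udf(self):
    # Hypothetical follow-up sketch only; names and behaviour are assumptions.
    from pyspark.sql.functions import pandas_udf
    from py4j.protocol import Py4JJavaError

    @pandas_udf('int')
    def foo(x):
        raise StopIteration()

    with self.assertRaises(Py4JJavaError) as cm:
        self.spark.range(0, 100).withColumn('v', foo('id')).show()
    self.assertIn('StopIteration in client code',
                  cm.exception.java_exception.toString())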

def test_validate_column_types(self):
from pyspark.sql.functions import udf, to_json
from pyspark.sql.column import _to_java_column
71 changes: 71 additions & 0 deletions python/pyspark/tests.py
@@ -161,6 +161,46 @@ def gen_gs(N, step=1):
self.assertEqual(k, len(vs))
self.assertEqual(list(range(k)), list(vs))

def test_stopiteration_is_raised(self):

def validate_exception(exc):
if isinstance(exc, RuntimeError):
self.assertEquals('StopIteration in client code', exc.args[0])
else:
self.assertIn('StopIteration in client code', exc.java_exception.toString())
Contributor:
seems we don't need this branch, as we always do self.assertRaises((Py4JJavaError, RuntimeError))


def stopit(*args, **kwargs):
raise StopIteration()

def legit_create_combiner(x):
return [x]

def legit_merge_value(x, y):
return x.append(y) or x

def legit_merge_combiners(x, y):
return x.extend(y) or x

data = [(x % 2, x) for x in range(100)]

# wrong create combiner
m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20)
with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
Member:
cm looks unused.

Member:
Let's pick up one explicit exception here too.

Member (@viirya, Jun 1, 2018):
Without the change in this test, you can still get a StopIteration in this test, can't you?

m.mergeValues(data)
validate_exception(cm.exception)

# wrong merge value
m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20)
with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
m.mergeValues(data)
validate_exception(cm.exception)

# wrong merge combiners
m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20)
with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data))
validate_exception(cm.exception)


class SorterTests(unittest.TestCase):
def test_in_memory_sort(self):
@@ -1246,6 +1286,37 @@ def test_pipe_unicode(self):
result = rdd.pipe('cat').collect()
self.assertEqual(data, result)

def test_stopiteration_in_client_code(self):

def a_rdd(keyed=False):
return self.sc.parallelize(
((x % 2, x) if keyed else x)
Member:
I would just create two RDDs and reuse them.

for x in range(10)
)
Member:
Shall we inline this?


def stopit(*x):
raise StopIteration()

def do_test(action, *args, **kwargs):
with self.assertRaises((Py4JJavaError, RuntimeError)) as cm:
Member:
Shall we pick up one explicit exception for each?

Contributor Author:
Can you clarify?

action(*args, **kwargs)
if isinstance(cm.exception, RuntimeError):
self.assertEquals('StopIteration in client code',
cm.exception.args[0])
else:
self.assertIn('StopIteration in client code',
Contributor:
ditto

cm.exception.java_exception.toString())

do_test(a_rdd().map(stopit).collect)
Member (@HyukjinKwon, May 24, 2018):
Maybe we could do:

self.assertRaises(RuntimeError, lambda: rdd.map(stopit).collect ... )

do_test(a_rdd().filter(stopit).collect)
do_test(a_rdd().cartesian(a_rdd()).flatMap(stopit).collect)
do_test(a_rdd().foreach, stopit)
do_test(a_rdd(keyed=True).reduceByKeyLocally, stopit)
do_test(a_rdd().reduce, stopit)
do_test(a_rdd().fold, 0, stopit)
do_test(a_rdd().aggregate, 0, stopit, lambda *x: 1)
do_test(a_rdd().aggregate, 0, lambda *x: 1, stopit)


class ProfilerTests(PySparkTestCase):
