-
Notifications
You must be signed in to change notification settings - Fork 28.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-23754][Python] Re-raising StopIteration in client code #21383
Changes from 1 commit
ec7854a
fddd031
ee54924
d739eea
f0f80ed
d59f0d5
b0af18e
167a75b
90b064d
75316af
026ecdd
f7b53c2
8fac2a8
5b5570b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,7 +49,7 @@ | |
from pyspark.storagelevel import StorageLevel | ||
from pyspark.resultiterable import ResultIterable | ||
from pyspark.shuffle import Aggregator, ExternalMerger, \ | ||
get_used_memory, ExternalSorter, ExternalGroupBy | ||
get_used_memory, ExternalSorter, ExternalGroupBy, safe_iter | ||
from pyspark.traceback_utils import SCCallSiteSync | ||
|
||
|
||
|
@@ -173,6 +173,7 @@ def ignore_unicode_prefix(f): | |
return f | ||
|
||
|
||
|
||
class Partitioner(object): | ||
def __init__(self, numPartitions, partitionFunc): | ||
self.numPartitions = numPartitions | ||
|
@@ -332,7 +333,7 @@ def map(self, f, preservesPartitioning=False): | |
[('a', 1), ('b', 1), ('c', 1)] | ||
""" | ||
def func(_, iterator): | ||
return map(f, iterator) | ||
return map(safe_iter(f), iterator) | ||
return self.mapPartitionsWithIndex(func, preservesPartitioning) | ||
|
||
def flatMap(self, f, preservesPartitioning=False): | ||
|
@@ -347,7 +348,7 @@ def flatMap(self, f, preservesPartitioning=False): | |
[(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] | ||
""" | ||
def func(s, iterator): | ||
return chain.from_iterable(map(f, iterator)) | ||
return chain.from_iterable(map(safe_iter(f), iterator)) | ||
return self.mapPartitionsWithIndex(func, preservesPartitioning) | ||
|
||
def mapPartitions(self, f, preservesPartitioning=False): | ||
|
@@ -410,7 +411,7 @@ def filter(self, f): | |
[2, 4] | ||
""" | ||
def func(iterator): | ||
return filter(f, iterator) | ||
return filter(safe_iter(f), iterator) | ||
return self.mapPartitions(func, True) | ||
|
||
def distinct(self, numPartitions=None): | ||
|
@@ -791,9 +792,11 @@ def foreach(self, f): | |
>>> def f(x): print(x) | ||
>>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) | ||
""" | ||
safe_f = safe_iter(f) | ||
|
||
def processPartition(iterator): | ||
for x in iterator: | ||
f(x) | ||
safe_f(x) | ||
return iter([]) | ||
self.mapPartitions(processPartition).count() # Force evaluation | ||
|
||
|
@@ -840,13 +843,15 @@ def reduce(self, f): | |
... | ||
ValueError: Can not reduce() empty RDD | ||
""" | ||
safe_f = safe_iter(f) | ||
|
||
def func(iterator): | ||
iterator = iter(iterator) | ||
try: | ||
initial = next(iterator) | ||
except StopIteration: | ||
return | ||
yield reduce(f, iterator, initial) | ||
yield reduce(safe_f, iterator, initial) | ||
|
||
vals = self.mapPartitions(func).collect() | ||
if vals: | ||
|
@@ -911,10 +916,12 @@ def fold(self, zeroValue, op): | |
>>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) | ||
15 | ||
""" | ||
safe_op = safe_iter(op) | ||
|
||
def func(iterator): | ||
acc = zeroValue | ||
for obj in iterator: | ||
acc = op(acc, obj) | ||
acc = safe_op(acc, obj) | ||
yield acc | ||
# collecting result of mapPartitions here ensures that the copy of | ||
# zeroValue provided to each partition is unique from the one provided | ||
|
@@ -943,16 +950,19 @@ def aggregate(self, zeroValue, seqOp, combOp): | |
>>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) | ||
(0, 0) | ||
""" | ||
safe_seqOp = safe_iter(seqOp) | ||
safe_combOp = safe_iter(combOp) | ||
|
||
def func(iterator): | ||
acc = zeroValue | ||
for obj in iterator: | ||
acc = seqOp(acc, obj) | ||
acc = safe_seqOp(acc, obj) | ||
yield acc | ||
# collecting result of mapPartitions here ensures that the copy of | ||
# zeroValue provided to each partition is unique from the one provided | ||
# to the final reduce call | ||
vals = self.mapPartitions(func).collect() | ||
return reduce(combOp, vals, zeroValue) | ||
return reduce(safe_combOp, vals, zeroValue) | ||
|
||
def treeAggregate(self, zeroValue, seqOp, combOp, depth=2): | ||
""" | ||
|
@@ -1636,15 +1646,17 @@ def reduceByKeyLocally(self, func): | |
>>> sorted(rdd.reduceByKeyLocally(add).items()) | ||
[('a', 2), ('b', 1)] | ||
""" | ||
safe_func = safe_iter(func) | ||
|
||
def reducePartition(iterator): | ||
m = {} | ||
for k, v in iterator: | ||
m[k] = func(m[k], v) if k in m else v | ||
m[k] = safe_func(m[k], v) if k in m else v | ||
yield m | ||
|
||
def mergeMaps(m1, m2): | ||
for k, v in m2.items(): | ||
m1[k] = func(m1[k], v) if k in m1 else v | ||
m1[k] = safe_func(m1[k], v) if k in m1 else v | ||
return m1 | ||
return self.mapPartitions(reducePartition).reduce(mergeMaps) | ||
|
||
|
@@ -1846,6 +1858,7 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, | |
>>> sorted(x.combineByKey(to_list, append, extend).collect()) | ||
[('a', [1, 2]), ('b', [1])] | ||
""" | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: unnecessary change There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's address this comment before merging it in for the backport. |
||
if numPartitions is None: | ||
numPartitions = self._defaultReducePartitions() | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -67,6 +67,19 @@ def get_used_memory(): | |
return 0 | ||
|
||
|
||
def safe_iter(f): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how about There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm also wondering why we put it in |
||
""" wraps f to make it safe (= does not lead to data loss) to use inside a for loop | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It sounds like this is a potential correctness issue, so the eventual fix for this should be backported to maintenance releases (at least the most recent ones and the next 2.3.x). I saw the examples provided on the linked JIRAs, but do you have an example of a realistic user workload where this problem can occur (i.e. a case where the problem is more subtle than explicitly throwing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, indeed that's how I stumbled upon this bug in the first place |
||
make StopIteration's raised inside f explicit | ||
""" | ||
def wrapper(*args, **kwargs): | ||
try: | ||
return f(*args, **kwargs) | ||
except StopIteration as exc: | ||
raise RuntimeError('StopIteration in client code', exc) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is it only a problem for client mode? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Most of the time, the user function is called within a loop (see e.g. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yea I got that, but why do we emphasise "client mode" here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wanted to be explicit and let the user know it's their problem. Can you suggest a better message? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh sorry I misread it as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, then what about There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What about |
||
|
||
return wrapper | ||
|
||
|
||
def _get_local_dirs(sub): | ||
""" Get all the directories """ | ||
path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp") | ||
|
@@ -94,9 +107,9 @@ class Aggregator(object): | |
""" | ||
|
||
def __init__(self, createCombiner, mergeValue, mergeCombiners): | ||
self.createCombiner = createCombiner | ||
self.mergeValue = mergeValue | ||
self.mergeCombiners = mergeCombiners | ||
self.createCombiner = safe_iter(createCombiner) | ||
self.mergeValue = safe_iter(mergeValue) | ||
self.mergeCombiners = safe_iter(mergeCombiners) | ||
|
||
|
||
class SimpleAggregator(Aggregator): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -900,6 +900,19 @@ def __call__(self, x): | |
self.assertEqual(f, f_.func) | ||
self.assertEqual(return_type, f_.returnType) | ||
|
||
def test_stopiteration_in_udf(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we also add tests for pandas_udf? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can also merge this PR first. I can follow up with the pandas udf tests. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes please, I am not really familiar with UDFs in general |
||
# test for SPARK-23754 | ||
from pyspark.sql.functions import udf | ||
from py4j.protocol import Py4JJavaError | ||
|
||
def foo(x): | ||
raise StopIteration() | ||
|
||
with self.assertRaises(Py4JJavaError) as cm: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto for |
||
self.spark.range(0, 1000).withColumn('v', udf(foo)).show() | ||
self.assertIn('StopIteration in client code', | ||
cm.exception.java_exception.toString()) | ||
|
||
def test_validate_column_types(self): | ||
from pyspark.sql.functions import udf, to_json | ||
from pyspark.sql.column import _to_java_column | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -161,6 +161,46 @@ def gen_gs(N, step=1): | |
self.assertEqual(k, len(vs)) | ||
self.assertEqual(list(range(k)), list(vs)) | ||
|
||
def test_stopiteration_is_raised(self): | ||
|
||
def validate_exception(exc): | ||
if isinstance(exc, RuntimeError): | ||
self.assertEquals('StopIteration in client code', exc.args[0]) | ||
else: | ||
self.assertIn('StopIteration in client code', exc.java_exception.toString()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. seems we don't need this branch, as we always do |
||
|
||
def stopit(*args, **kwargs): | ||
raise StopIteration() | ||
|
||
def legit_create_combiner(x): | ||
return [x] | ||
|
||
def legit_merge_value(x, y): | ||
return x.append(y) or x | ||
|
||
def legit_merge_combiners(x, y): | ||
return x.extend(y) or x | ||
|
||
data = [(x % 2, x) for x in range(100)] | ||
|
||
# wrong create combiner | ||
m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) | ||
with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's pick up one explicit exception here too. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Without the change in this test, you still can get a |
||
m.mergeValues(data) | ||
validate_exception(cm.exception) | ||
|
||
# wrong merge value | ||
m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) | ||
with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: | ||
m.mergeValues(data) | ||
validate_exception(cm.exception) | ||
|
||
# wrong merge combiners | ||
m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) | ||
with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: | ||
m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) | ||
validate_exception(cm.exception) | ||
|
||
|
||
class SorterTests(unittest.TestCase): | ||
def test_in_memory_sort(self): | ||
|
@@ -1246,6 +1286,37 @@ def test_pipe_unicode(self): | |
result = rdd.pipe('cat').collect() | ||
self.assertEqual(data, result) | ||
|
||
def test_stopiteration_in_client_code(self): | ||
|
||
def a_rdd(keyed=False): | ||
return self.sc.parallelize( | ||
((x % 2, x) if keyed else x) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would just create two RDDs and reuse it. |
||
for x in range(10) | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we inline this? |
||
|
||
def stopit(*x): | ||
raise StopIteration() | ||
|
||
def do_test(action, *args, **kwargs): | ||
with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we pick up one explicit exception for each? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you clarify? |
||
action(*args, **kwargs) | ||
if isinstance(cm.exception, RuntimeError): | ||
self.assertEquals('StopIteration in client code', | ||
cm.exception.args[0]) | ||
else: | ||
self.assertIn('StopIteration in client code', | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
cm.exception.java_exception.toString()) | ||
|
||
do_test(a_rdd().map(stopit).collect) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we could do:
|
||
do_test(a_rdd().filter(stopit).collect) | ||
do_test(a_rdd().cartesian(a_rdd()).flatMap(stopit).collect) | ||
do_test(a_rdd().foreach, stopit) | ||
do_test(a_rdd(keyed=True).reduceByKeyLocally, stopit) | ||
do_test(a_rdd().reduce, stopit) | ||
do_test(a_rdd().fold, 0, stopit) | ||
do_test(a_rdd().aggregate, 0, stopit, lambda *x: 1) | ||
do_test(a_rdd().aggregate, 0, lambda *x: 1, stopit) | ||
|
||
|
||
class ProfilerTests(PySparkTestCase): | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
btw, I would revert unrelated changes to make the backport easier.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not really familiar with the codebase, can you provide more details please?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this new line is an unnecessary change