From ec7854a8504ec08485b3536ea71483cce46f9500 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Mon, 21 May 2018 19:30:10 +0200 Subject: [PATCH 01/14] re-raising StopIteration in client code --- python/pyspark/rdd.py | 35 ++++++++++++------ python/pyspark/shuffle.py | 19 ++++++++-- python/pyspark/sql/tests.py | 13 +++++++ python/pyspark/tests.py | 71 +++++++++++++++++++++++++++++++++++++ 4 files changed, 124 insertions(+), 14 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4b44f76747264..257b435ea7e77 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -49,7 +49,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, ExternalMerger, \ - get_used_memory, ExternalSorter, ExternalGroupBy + get_used_memory, ExternalSorter, ExternalGroupBy, safe_iter from pyspark.traceback_utils import SCCallSiteSync @@ -173,6 +173,7 @@ def ignore_unicode_prefix(f): return f + class Partitioner(object): def __init__(self, numPartitions, partitionFunc): self.numPartitions = numPartitions @@ -332,7 +333,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return map(f, iterator) + return map(safe_iter(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -347,7 +348,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(map(f, iterator)) + return chain.from_iterable(map(safe_iter(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -410,7 +411,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return filter(f, iterator) + return filter(safe_iter(f), iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -791,9 +792,11 @@ def foreach(self, f): >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ + safe_f = safe_iter(f) + def processPartition(iterator): for x in iterator: - f(x) + safe_f(x) return iter([]) self.mapPartitions(processPartition).count() # Force evaluation @@ -840,13 +843,15 @@ def reduce(self, f): ... 
ValueError: Can not reduce() empty RDD """ + safe_f = safe_iter(f) + def func(iterator): iterator = iter(iterator) try: initial = next(iterator) except StopIteration: return - yield reduce(f, iterator, initial) + yield reduce(safe_f, iterator, initial) vals = self.mapPartitions(func).collect() if vals: @@ -911,10 +916,12 @@ def fold(self, zeroValue, op): >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 """ + safe_op = safe_iter(op) + def func(iterator): acc = zeroValue for obj in iterator: - acc = op(acc, obj) + acc = safe_op(acc, obj) yield acc # collecting result of mapPartitions here ensures that the copy of # zeroValue provided to each partition is unique from the one provided @@ -943,16 +950,19 @@ def aggregate(self, zeroValue, seqOp, combOp): >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) (0, 0) """ + safe_seqOp = safe_iter(seqOp) + safe_combOp = safe_iter(combOp) + def func(iterator): acc = zeroValue for obj in iterator: - acc = seqOp(acc, obj) + acc = safe_seqOp(acc, obj) yield acc # collecting result of mapPartitions here ensures that the copy of # zeroValue provided to each partition is unique from the one provided # to the final reduce call vals = self.mapPartitions(func).collect() - return reduce(combOp, vals, zeroValue) + return reduce(safe_combOp, vals, zeroValue) def treeAggregate(self, zeroValue, seqOp, combOp, depth=2): """ @@ -1636,15 +1646,17 @@ def reduceByKeyLocally(self, func): >>> sorted(rdd.reduceByKeyLocally(add).items()) [('a', 2), ('b', 1)] """ + safe_func = safe_iter(func) + def reducePartition(iterator): m = {} for k, v in iterator: - m[k] = func(m[k], v) if k in m else v + m[k] = safe_func(m[k], v) if k in m else v yield m def mergeMaps(m1, m2): for k, v in m2.items(): - m1[k] = func(m1[k], v) if k in m1 else v + m1[k] = safe_func(m1[k], v) if k in m1 else v return m1 return self.mapPartitions(reducePartition).reduce(mergeMaps) @@ -1846,6 +1858,7 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, >>> sorted(x.combineByKey(to_list, append, extend).collect()) [('a', [1, 2]), ('b', [1])] """ + if numPartitions is None: numPartitions = self._defaultReducePartitions() diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 02c773302e9da..7445f66714f03 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -67,6 +67,19 @@ def get_used_memory(): return 0 +def safe_iter(f): + """ wraps f to make it safe (= does not lead to data loss) to use inside a for loop + make StopIteration's raised inside f explicit + """ + def wrapper(*args, **kwargs): + try: + return f(*args, **kwargs) + except StopIteration as exc: + raise RuntimeError('StopIteration in client code', exc) + + return wrapper + + def _get_local_dirs(sub): """ Get all the directories """ path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp") @@ -94,9 +107,9 @@ class Aggregator(object): """ def __init__(self, createCombiner, mergeValue, mergeCombiners): - self.createCombiner = createCombiner - self.mergeValue = mergeValue - self.mergeCombiners = mergeCombiners + self.createCombiner = safe_iter(createCombiner) + self.mergeValue = safe_iter(mergeValue) + self.mergeCombiners = safe_iter(mergeCombiners) class SimpleAggregator(Aggregator): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 16aa9378ad8ee..f651f2b486ca4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -900,6 +900,19 @@ def __call__(self, x): self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) + def 
test_stopiteration_in_udf(self): + # test for SPARK-23754 + from pyspark.sql.functions import udf + from py4j.protocol import Py4JJavaError + + def foo(x): + raise StopIteration() + + with self.assertRaises(Py4JJavaError) as cm: + self.spark.range(0, 1000).withColumn('v', udf(foo)).show() + self.assertIn('StopIteration in client code', + cm.exception.java_exception.toString()) + def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 498d6b57e4353..14af8e1fef4bd 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -161,6 +161,46 @@ def gen_gs(N, step=1): self.assertEqual(k, len(vs)) self.assertEqual(list(range(k)), list(vs)) + def test_stopiteration_is_raised(self): + + def validate_exception(exc): + if isinstance(exc, RuntimeError): + self.assertEquals('StopIteration in client code', exc.args[0]) + else: + self.assertIn('StopIteration in client code', exc.java_exception.toString()) + + def stopit(*args, **kwargs): + raise StopIteration() + + def legit_create_combiner(x): + return [x] + + def legit_merge_value(x, y): + return x.append(y) or x + + def legit_merge_combiners(x, y): + return x.extend(y) or x + + data = [(x % 2, x) for x in range(100)] + + # wrong create combiner + m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + validate_exception(cm.exception) + + # wrong merge value + m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + validate_exception(cm.exception) + + # wrong merge combiners + m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) + validate_exception(cm.exception) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): @@ -1246,6 +1286,37 @@ def test_pipe_unicode(self): result = rdd.pipe('cat').collect() self.assertEqual(data, result) + def test_stopiteration_in_client_code(self): + + def a_rdd(keyed=False): + return self.sc.parallelize( + ((x % 2, x) if keyed else x) + for x in range(10) + ) + + def stopit(*x): + raise StopIteration() + + def do_test(action, *args, **kwargs): + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + action(*args, **kwargs) + if isinstance(cm.exception, RuntimeError): + self.assertEquals('StopIteration in client code', + cm.exception.args[0]) + else: + self.assertIn('StopIteration in client code', + cm.exception.java_exception.toString()) + + do_test(a_rdd().map(stopit).collect) + do_test(a_rdd().filter(stopit).collect) + do_test(a_rdd().cartesian(a_rdd()).flatMap(stopit).collect) + do_test(a_rdd().foreach, stopit) + do_test(a_rdd(keyed=True).reduceByKeyLocally, stopit) + do_test(a_rdd().reduce, stopit) + do_test(a_rdd().fold, 0, stopit) + do_test(a_rdd().aggregate, 0, stopit, lambda *x: 1) + do_test(a_rdd().aggregate, 0, lambda *x: 1, stopit) + class ProfilerTests(PySparkTestCase): From fddd031bbe4dda108739169f0a27eacae8f33099 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Tue, 22 May 2018 16:15:49 +0200 Subject: [PATCH 02/14] moved safe_iter to util module and more descriptive name --- python/pyspark/rdd.py | 22 +++++++++++----------- 
python/pyspark/shuffle.py | 20 ++++---------------- python/pyspark/util.py | 13 +++++++++++++ 3 files changed, 28 insertions(+), 27 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 257b435ea7e77..184fd1b60cd97 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -49,8 +49,9 @@ from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, ExternalMerger, \ - get_used_memory, ExternalSorter, ExternalGroupBy, safe_iter + get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync +from pyspark.util import fail_on_StopIteration __all__ = ["RDD"] @@ -173,7 +174,6 @@ def ignore_unicode_prefix(f): return f - class Partitioner(object): def __init__(self, numPartitions, partitionFunc): self.numPartitions = numPartitions @@ -333,7 +333,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return map(safe_iter(f), iterator) + return map(fail_on_StopIteration(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -348,7 +348,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(map(safe_iter(f), iterator)) + return chain.from_iterable(map(fail_on_StopIteration(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -411,7 +411,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return filter(safe_iter(f), iterator) + return filter(fail_on_StopIteration(f), iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -792,7 +792,7 @@ def foreach(self, f): >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ - safe_f = safe_iter(f) + safe_f = fail_on_StopIteration(f) def processPartition(iterator): for x in iterator: @@ -843,7 +843,7 @@ def reduce(self, f): ... 
ValueError: Can not reduce() empty RDD """ - safe_f = safe_iter(f) + safe_f = fail_on_StopIteration(f) def func(iterator): iterator = iter(iterator) @@ -916,7 +916,7 @@ def fold(self, zeroValue, op): >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 """ - safe_op = safe_iter(op) + safe_op = fail_on_StopIteration(op) def func(iterator): acc = zeroValue @@ -950,8 +950,8 @@ def aggregate(self, zeroValue, seqOp, combOp): >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) (0, 0) """ - safe_seqOp = safe_iter(seqOp) - safe_combOp = safe_iter(combOp) + safe_seqOp = fail_on_StopIteration(seqOp) + safe_combOp = fail_on_StopIteration(combOp) def func(iterator): acc = zeroValue @@ -1646,7 +1646,7 @@ def reduceByKeyLocally(self, func): >>> sorted(rdd.reduceByKeyLocally(add).items()) [('a', 2), ('b', 1)] """ - safe_func = safe_iter(func) + safe_func = fail_on_StopIteration(func) def reducePartition(iterator): m = {} diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 7445f66714f03..250c3233c976f 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -28,6 +28,7 @@ import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ CompressedSerializer, AutoBatchedSerializer +from pyspark.util import fail_on_StopIteration try: @@ -67,19 +68,6 @@ def get_used_memory(): return 0 -def safe_iter(f): - """ wraps f to make it safe (= does not lead to data loss) to use inside a for loop - make StopIteration's raised inside f explicit - """ - def wrapper(*args, **kwargs): - try: - return f(*args, **kwargs) - except StopIteration as exc: - raise RuntimeError('StopIteration in client code', exc) - - return wrapper - - def _get_local_dirs(sub): """ Get all the directories """ path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp") @@ -107,9 +95,9 @@ class Aggregator(object): """ def __init__(self, createCombiner, mergeValue, mergeCombiners): - self.createCombiner = safe_iter(createCombiner) - self.mergeValue = safe_iter(mergeValue) - self.mergeCombiners = safe_iter(mergeCombiners) + self.createCombiner = fail_on_StopIteration(createCombiner) + self.mergeValue = fail_on_StopIteration(mergeValue) + self.mergeCombiners = fail_on_StopIteration(mergeCombiners) class SimpleAggregator(Aggregator): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 59cc2a6329350..5807fde0812f8 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -89,6 +89,19 @@ def majorMinorVersion(sparkVersion): " version numbers.") +def fail_on_StopIteration(f): + """ wraps f to make it safe (= does not lead to data loss) to use inside a for loop + make StopIteration's raised inside f explicit + """ + def wrapper(*args, **kwargs): + try: + return f(*args, **kwargs) + except StopIteration as exc: + raise RuntimeError('StopIteration in client code', exc) + + return wrapper + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() From ee54924b9d23e616d432497c77e46671ad15ef88 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Tue, 22 May 2018 16:16:11 +0200 Subject: [PATCH 03/14] removed redundancy from tests --- python/pyspark/sql/tests.py | 2 -- python/pyspark/tests.py | 15 --------------- 2 files changed, 17 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f651f2b486ca4..f66423e05e9b6 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -910,8 +910,6 @@ def foo(x): with self.assertRaises(Py4JJavaError) as cm: 
self.spark.range(0, 1000).withColumn('v', udf(foo)).show() - self.assertIn('StopIteration in client code', - cm.exception.java_exception.toString()) def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 14af8e1fef4bd..383fdde59aad0 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -163,12 +163,6 @@ def gen_gs(N, step=1): def test_stopiteration_is_raised(self): - def validate_exception(exc): - if isinstance(exc, RuntimeError): - self.assertEquals('StopIteration in client code', exc.args[0]) - else: - self.assertIn('StopIteration in client code', exc.java_exception.toString()) - def stopit(*args, **kwargs): raise StopIteration() @@ -187,19 +181,16 @@ def legit_merge_combiners(x, y): m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: m.mergeValues(data) - validate_exception(cm.exception) # wrong merge value m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: m.mergeValues(data) - validate_exception(cm.exception) # wrong merge combiners m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) - validate_exception(cm.exception) class SorterTests(unittest.TestCase): @@ -1300,12 +1291,6 @@ def stopit(*x): def do_test(action, *args, **kwargs): with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: action(*args, **kwargs) - if isinstance(cm.exception, RuntimeError): - self.assertEquals('StopIteration in client code', - cm.exception.args[0]) - else: - self.assertIn('StopIteration in client code', - cm.exception.java_exception.toString()) do_test(a_rdd().map(stopit).collect) do_test(a_rdd().filter(stopit).collect) From d739eea9e8ed07dad9dd9b1a795ff21e8f915694 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Thu, 24 May 2018 15:16:58 +0200 Subject: [PATCH 04/14] improved doc, error message and code style --- python/pyspark/rdd.py | 35 +++++++++++++++++------------------ python/pyspark/shuffle.py | 8 ++++---- python/pyspark/util.py | 9 +++++---- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 184fd1b60cd97..ac127ac5d61c1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -51,7 +51,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync -from pyspark.util import fail_on_StopIteration +from pyspark.util import fail_on_stopiteration __all__ = ["RDD"] @@ -333,7 +333,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return map(fail_on_StopIteration(f), iterator) + return map(fail_on_stopiteration(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -348,7 +348,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(map(fail_on_StopIteration(f), iterator)) + return chain.from_iterable(map(fail_on_stopiteration(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, 
preservesPartitioning=False): @@ -411,7 +411,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return filter(fail_on_StopIteration(f), iterator) + return filter(fail_on_stopiteration(f), iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -792,11 +792,11 @@ def foreach(self, f): >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ - safe_f = fail_on_StopIteration(f) + f = fail_on_stopiteration(f) def processPartition(iterator): for x in iterator: - safe_f(x) + f(x) return iter([]) self.mapPartitions(processPartition).count() # Force evaluation @@ -843,7 +843,7 @@ def reduce(self, f): ... ValueError: Can not reduce() empty RDD """ - safe_f = fail_on_StopIteration(f) + f = fail_on_stopiteration(f) def func(iterator): iterator = iter(iterator) @@ -851,7 +851,7 @@ def func(iterator): initial = next(iterator) except StopIteration: return - yield reduce(safe_f, iterator, initial) + yield reduce(f, iterator, initial) vals = self.mapPartitions(func).collect() if vals: @@ -916,12 +916,12 @@ def fold(self, zeroValue, op): >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 """ - safe_op = fail_on_StopIteration(op) + op = fail_on_stopiteration(op) def func(iterator): acc = zeroValue for obj in iterator: - acc = safe_op(acc, obj) + acc = op(acc, obj) yield acc # collecting result of mapPartitions here ensures that the copy of # zeroValue provided to each partition is unique from the one provided @@ -950,19 +950,19 @@ def aggregate(self, zeroValue, seqOp, combOp): >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) (0, 0) """ - safe_seqOp = fail_on_StopIteration(seqOp) - safe_combOp = fail_on_StopIteration(combOp) + seqOp = fail_on_stopiteration(seqOp) + combOp = fail_on_stopiteration(combOp) def func(iterator): acc = zeroValue for obj in iterator: - acc = safe_seqOp(acc, obj) + acc = seqOp(acc, obj) yield acc # collecting result of mapPartitions here ensures that the copy of # zeroValue provided to each partition is unique from the one provided # to the final reduce call vals = self.mapPartitions(func).collect() - return reduce(safe_combOp, vals, zeroValue) + return reduce(combOp, vals, zeroValue) def treeAggregate(self, zeroValue, seqOp, combOp, depth=2): """ @@ -1646,17 +1646,17 @@ def reduceByKeyLocally(self, func): >>> sorted(rdd.reduceByKeyLocally(add).items()) [('a', 2), ('b', 1)] """ - safe_func = fail_on_StopIteration(func) + func = fail_on_stopiteration(func) def reducePartition(iterator): m = {} for k, v in iterator: - m[k] = safe_func(m[k], v) if k in m else v + m[k] = func(m[k], v) if k in m else v yield m def mergeMaps(m1, m2): for k, v in m2.items(): - m1[k] = safe_func(m1[k], v) if k in m1 else v + m1[k] = func(m1[k], v) if k in m1 else v return m1 return self.mapPartitions(reducePartition).reduce(mergeMaps) @@ -1858,7 +1858,6 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, >>> sorted(x.combineByKey(to_list, append, extend).collect()) [('a', [1, 2]), ('b', [1])] """ - if numPartitions is None: numPartitions = self._defaultReducePartitions() diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 250c3233c976f..bd0ac0039ffe1 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -28,7 +28,7 @@ import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ CompressedSerializer, AutoBatchedSerializer -from pyspark.util import fail_on_StopIteration +from pyspark.util import fail_on_stopiteration 
try: @@ -95,9 +95,9 @@ class Aggregator(object): """ def __init__(self, createCombiner, mergeValue, mergeCombiners): - self.createCombiner = fail_on_StopIteration(createCombiner) - self.mergeValue = fail_on_StopIteration(mergeValue) - self.mergeCombiners = fail_on_StopIteration(mergeCombiners) + self.createCombiner = fail_on_stopiteration(createCombiner) + self.mergeValue = fail_on_stopiteration(mergeValue) + self.mergeCombiners = fail_on_stopiteration(mergeCombiners) class SimpleAggregator(Aggregator): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 5807fde0812f8..e77a40e8b808f 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -89,15 +89,16 @@ def majorMinorVersion(sparkVersion): " version numbers.") -def fail_on_StopIteration(f): - """ wraps f to make it safe (= does not lead to data loss) to use inside a for loop - make StopIteration's raised inside f explicit +def fail_on_stopiteration(f): + """ + Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' + prevents silent loss of data when 'f' is used in a for loop """ def wrapper(*args, **kwargs): try: return f(*args, **kwargs) except StopIteration as exc: - raise RuntimeError('StopIteration in client code', exc) + raise RuntimeError("Caught StopIteration thrown from user's code; failing the task", exc) return wrapper From f0f80ed1b8333bbab841a59f151deff18bc73447 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Thu, 24 May 2018 15:17:46 +0200 Subject: [PATCH 05/14] improved tests --- python/pyspark/sql/tests.py | 2 +- python/pyspark/tests.py | 32 +++++++++++++------------------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f66423e05e9b6..fb592989ce4b0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -908,7 +908,7 @@ def test_stopiteration_in_udf(self): def foo(x): raise StopIteration() - with self.assertRaises(Py4JJavaError) as cm: + with self.assertRaises(Py4JJavaError): self.spark.range(0, 1000).withColumn('v', udf(foo)).show() def test_validate_column_types(self): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 383fdde59aad0..18f88cee89e1f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1279,28 +1279,22 @@ def test_pipe_unicode(self): def test_stopiteration_in_client_code(self): - def a_rdd(keyed=False): - return self.sc.parallelize( - ((x % 2, x) if keyed else x) - for x in range(10) - ) - def stopit(*x): raise StopIteration() - def do_test(action, *args, **kwargs): - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - action(*args, **kwargs) - - do_test(a_rdd().map(stopit).collect) - do_test(a_rdd().filter(stopit).collect) - do_test(a_rdd().cartesian(a_rdd()).flatMap(stopit).collect) - do_test(a_rdd().foreach, stopit) - do_test(a_rdd(keyed=True).reduceByKeyLocally, stopit) - do_test(a_rdd().reduce, stopit) - do_test(a_rdd().fold, 0, stopit) - do_test(a_rdd().aggregate, 0, stopit, lambda *x: 1) - do_test(a_rdd().aggregate, 0, lambda *x: 1, stopit) + seq_rdd = self.sc.parallelize(range(10)) + keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + exc = Py4JJavaError, RuntimeError + + self.assertRaises(exc, seq_rdd.map(stopit).collect) + self.assertRaises(exc, seq_rdd.filter(stopit).collect) + self.assertRaises(exc, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + self.assertRaises(exc, seq_rdd.foreach, stopit) + self.assertRaises(exc, keyed_rdd.reduceByKeyLocally, stopit) + 
self.assertRaises(exc, seq_rdd.reduce, stopit) + self.assertRaises(exc, seq_rdd.fold, 0, stopit) + self.assertRaises(exc, seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaises(exc, seq_rdd.aggregate, 0, lambda *x: 1, stopit) class ProfilerTests(PySparkTestCase): From d59f0d5a2735713bb7e218cfcda2b494edfcf522 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Thu, 24 May 2018 15:37:53 +0200 Subject: [PATCH 06/14] fixed style --- python/pyspark/util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index e77a40e8b808f..938e729260bba 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -98,7 +98,10 @@ def wrapper(*args, **kwargs): try: return f(*args, **kwargs) except StopIteration as exc: - raise RuntimeError("Caught StopIteration thrown from user's code; failing the task", exc) + raise RuntimeError( + "Caught StopIteration thrown from user's code; failing the task", + exc + ) return wrapper From b0af18e400c01095dd87589260ce80e9712a9f07 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Thu, 24 May 2018 18:44:38 +0200 Subject: [PATCH 07/14] fixed udf and its test --- python/pyspark/sql/tests.py | 2 +- python/pyspark/sql/udf.py | 4 ++-- python/pyspark/util.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fb592989ce4b0..53d6dff9eb1c4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -909,7 +909,7 @@ def foo(x): raise StopIteration() with self.assertRaises(Py4JJavaError): - self.spark.range(0, 1000).withColumn('v', udf(foo)).show() + self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show() def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 9dbe49b831cef..f41e307d6992e 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -25,7 +25,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\ to_arrow_type, to_arrow_schema -from pyspark.util import _get_argspec +from pyspark.util import _get_argspec, fail_on_stopiteration __all__ = ["UDFRegistration"] @@ -92,7 +92,7 @@ def __init__(self, func, raise TypeError( "Invalid evalType: evalType should be an int but is {}".format(evalType)) - self.func = func + self.func = fail_on_stopiteration(func) self._returnType = returnType # Stores UserDefinedPythonFunctions jobj, once initialized self._returnType_placeholder = None diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 938e729260bba..fa1b1c2da0b21 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -91,8 +91,8 @@ def majorMinorVersion(sparkVersion): def fail_on_stopiteration(f): """ - Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' - prevents silent loss of data when 'f' is used in a for loop + Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' + prevents silent loss of data when 'f' is used in a for loop """ def wrapper(*args, **kwargs): try: From 167a75b81599e176f851daa2566b359f72264f61 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Thu, 24 May 2018 20:18:40 +0200 Subject: [PATCH 08/14] preserving metadata of wrapped function --- python/pyspark/sql/udf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/udf.py 
b/python/pyspark/sql/udf.py index f41e307d6992e..b35b4bbdb5812 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -92,7 +92,7 @@ def __init__(self, func, raise TypeError( "Invalid evalType: evalType should be an int but is {}".format(evalType)) - self.func = fail_on_stopiteration(func) + self.func = func self._returnType = returnType # Stores UserDefinedPythonFunctions jobj, once initialized self._returnType_placeholder = None @@ -157,7 +157,7 @@ def _create_judf(self): spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext - wrapped_func = _wrap_function(sc, self.func, self.returnType) + wrapped_func = _wrap_function(sc, fail_on_stopiteration(self.func), self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( self._name, wrapped_func, jdt, self.evalType, self.deterministic) From 90b064ddd2db562e90dd55846f1331779e795460 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Thu, 24 May 2018 20:19:16 +0200 Subject: [PATCH 09/14] catching relevant exceptions only --- python/pyspark/tests.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 18f88cee89e1f..3b37cc028c1b7 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1284,17 +1284,20 @@ def stopit(*x): seq_rdd = self.sc.parallelize(range(10)) keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) - exc = Py4JJavaError, RuntimeError - - self.assertRaises(exc, seq_rdd.map(stopit).collect) - self.assertRaises(exc, seq_rdd.filter(stopit).collect) - self.assertRaises(exc, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) - self.assertRaises(exc, seq_rdd.foreach, stopit) - self.assertRaises(exc, keyed_rdd.reduceByKeyLocally, stopit) - self.assertRaises(exc, seq_rdd.reduce, stopit) - self.assertRaises(exc, seq_rdd.fold, 0, stopit) - self.assertRaises(exc, seq_rdd.aggregate, 0, stopit, lambda *x: 1) - self.assertRaises(exc, seq_rdd.aggregate, 0, lambda *x: 1, stopit) + + self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) + self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit) + + # the exception raised is non-deterministic + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, lambda *x: 1, stopit) class ProfilerTests(PySparkTestCase): From 75316af5f366d9c0386a9396fc981a9294541cb0 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Sat, 26 May 2018 16:56:01 +0200 Subject: [PATCH 10/14] preserving argspecs of wrapped function --- python/pyspark/sql/tests.py | 8 +++++++- python/pyspark/util.py | 16 +++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 53d6dff9eb1c4..e5b8fde4cd94f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -908,9 +908,15 @@ def test_stopiteration_in_udf(self): def foo(x): raise StopIteration() - with self.assertRaises(Py4JJavaError): + with self.assertRaises(Py4JJavaError) as cm: 
self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show() + self.assertIn( + "Caught StopIteration thrown from user's code; failing the task", + cm.exception.java_exception.toString() + ) + + def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column diff --git a/python/pyspark/util.py b/python/pyspark/util.py index fa1b1c2da0b21..178dd94034414 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -23,6 +23,8 @@ __all__ = [] +WRAPPED_ARGSPEC_ATTR = '_wrapped_argspec' + def _exception_message(excp): """Return the message from an exception as either a str or unicode object. Supports both @@ -55,7 +57,9 @@ def _get_argspec(f): """ # `getargspec` is deprecated since python3.0 (incompatible with function annotations). # See SPARK-23569. - if sys.version_info[0] < 3: + if hasattr(f, WRAPPED_ARGSPEC_ATTR): + argspec = getattr(f, WRAPPED_ARGSPEC_ATTR) + elif sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: argspec = inspect.getfullargspec(f) @@ -103,6 +107,16 @@ def wrapper(*args, **kwargs): exc ) + # prevent inspect to fail + # e.g. inspect.getargspec(sum) raises + # TypeError: is not a Python function + try: + argspec = _get_argspec(f) + except TypeError: + pass + else: + setattr(wrapper, WRAPPED_ARGSPEC_ATTR, _get_argspec(f)) + return wrapper From 026ecddacb847d624cd53150e82c011b6befafc0 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Sat, 26 May 2018 17:23:32 +0200 Subject: [PATCH 11/14] style --- python/pyspark/sql/tests.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e5b8fde4cd94f..1f91d2c181685 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -916,7 +916,6 @@ def foo(x): cm.exception.java_exception.toString() ) - def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column From f7b53c222e4341b59d3588017718d80ecb37a473 Mon Sep 17 00:00:00 2001 From: e-dorigatti Date: Tue, 29 May 2018 09:43:08 +0200 Subject: [PATCH 12/14] saving argspec in udf --- python/pyspark/sql/udf.py | 12 +++++++++++- python/pyspark/util.py | 16 ++-------------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index b35b4bbdb5812..fb71fa1f60f14 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -157,7 +157,17 @@ def _create_judf(self): spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext - wrapped_func = _wrap_function(sc, fail_on_stopiteration(self.func), self.returnType) + func = fail_on_stopiteration(self.func) + + # prevent inspect to fail + # e.g. inspect.getargspec(sum) raises + # TypeError: is not a Python function + try: + func._argspec = _get_argspec(self.func) + except TypeError: + pass + + wrapped_func = _wrap_function(sc, func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( self._name, wrapped_func, jdt, self.evalType, self.deterministic) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 178dd94034414..8b646d573f207 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -23,8 +23,6 @@ __all__ = [] -WRAPPED_ARGSPEC_ATTR = '_wrapped_argspec' - def _exception_message(excp): """Return the message from an exception as either a str or unicode object. 
Supports both @@ -57,8 +55,8 @@ def _get_argspec(f): """ # `getargspec` is deprecated since python3.0 (incompatible with function annotations). # See SPARK-23569. - if hasattr(f, WRAPPED_ARGSPEC_ATTR): - argspec = getattr(f, WRAPPED_ARGSPEC_ATTR) + if hasattr(f, '_argspec'): + argspec = f._argspec elif sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: @@ -107,16 +105,6 @@ def wrapper(*args, **kwargs): exc ) - # prevent inspect to fail - # e.g. inspect.getargspec(sum) raises - # TypeError: is not a Python function - try: - argspec = _get_argspec(f) - except TypeError: - pass - else: - setattr(wrapper, WRAPPED_ARGSPEC_ATTR, _get_argspec(f)) - return wrapper From 8fac2a80deb79030dee161e0d86b7b090bc892a7 Mon Sep 17 00:00:00 2001 From: edorigatti Date: Tue, 29 May 2018 17:02:30 +0200 Subject: [PATCH 13/14] saving signature only for pandas udf, removed useless try/except --- python/pyspark/sql/udf.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index fb71fa1f60f14..c8fb49d7c2b65 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -159,13 +159,13 @@ def _create_judf(self): func = fail_on_stopiteration(self.func) - # prevent inspect to fail - # e.g. inspect.getargspec(sum) raises - # TypeError: is not a Python function - try: + # for pandas UDFs the worker needs to know if the function takes + # one or two arguments, but the signature is lost when wrapping with + # fail_on_stopiteration, so we store it here + if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): func._argspec = _get_argspec(self.func) - except TypeError: - pass wrapped_func = _wrap_function(sc, func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) From 5b5570b7d4a4e71d470dbb9e763b50a948d4195c Mon Sep 17 00:00:00 2001 From: edorigatti Date: Wed, 30 May 2018 10:30:09 +0200 Subject: [PATCH 14/14] comment explaining hack --- python/pyspark/util.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 8b646d573f207..e95a9b523393f 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -53,13 +53,16 @@ def _get_argspec(f): """ Get argspec of a function. Supports both Python 2 and Python 3. """ - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. + if hasattr(f, '_argspec'): + # only used for pandas UDF: they wrap the user function, losing its signature + # workers need this signature, so UDF saves it here argspec = f._argspec elif sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. argspec = inspect.getfullargspec(f) return argspec
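
Editor's note (not part of the patch series above): the behaviour these patches guard against is easy to reproduce in plain Python, without Spark. The sketch below copies the `fail_on_stopiteration` wrapper introduced in `python/pyspark/util.py` and applies it the same way `RDD.map` does (`map(fail_on_stopiteration(f), iterator)`); `buggy_user_function` and the sample data are hypothetical stand-ins for user code, added only for illustration.

```python
def fail_on_stopiteration(f):
    """Mirror of pyspark.util.fail_on_stopiteration from this patch series:
    wrap ``f`` so a StopIteration escaping from it becomes a RuntimeError."""
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except StopIteration as exc:
            raise RuntimeError(
                "Caught StopIteration thrown from user's code; failing the task",
                exc
            )
    return wrapper


def buggy_user_function(x):
    # Hypothetical user code that (incorrectly) raises StopIteration.
    if x == 3:
        raise StopIteration()
    return x * 10


data = list(range(10))

# Unwrapped: the StopIteration escapes the user function, propagates out of
# map's __next__, and the consuming loop treats it as "iterator exhausted".
# The result is silently truncated to [0, 10, 20] -- the data loss this
# series is meant to prevent.
print(list(map(buggy_user_function, data)))

# Wrapped: the same error is surfaced explicitly and fails the computation.
try:
    list(map(fail_on_stopiteration(buggy_user_function), data))
except RuntimeError as err:
    print("RuntimeError:", err.args[0])
```

Running this prints the truncated list first, then the RuntimeError message; inside PySpark the RuntimeError is what reaches the driver as a `Py4JJavaError`, which is why the tests in `python/pyspark/tests.py` and `python/pyspark/sql/tests.py` assert on that exception type.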