diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 184fd1b60cd97..ac127ac5d61c1 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -51,7 +51,7 @@
 from pyspark.shuffle import Aggregator, ExternalMerger, \
     get_used_memory, ExternalSorter, ExternalGroupBy
 from pyspark.traceback_utils import SCCallSiteSync
-from pyspark.util import fail_on_StopIteration
+from pyspark.util import fail_on_stopiteration
 
 
 __all__ = ["RDD"]
@@ -333,7 +333,7 @@ def map(self, f, preservesPartitioning=False):
         [('a', 1), ('b', 1), ('c', 1)]
         """
         def func(_, iterator):
-            return map(fail_on_StopIteration(f), iterator)
+            return map(fail_on_stopiteration(f), iterator)
         return self.mapPartitionsWithIndex(func, preservesPartitioning)
 
     def flatMap(self, f, preservesPartitioning=False):
@@ -348,7 +348,7 @@ def flatMap(self, f, preservesPartitioning=False):
         [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
         """
         def func(s, iterator):
-            return chain.from_iterable(map(fail_on_StopIteration(f), iterator))
+            return chain.from_iterable(map(fail_on_stopiteration(f), iterator))
         return self.mapPartitionsWithIndex(func, preservesPartitioning)
 
     def mapPartitions(self, f, preservesPartitioning=False):
@@ -411,7 +411,7 @@ def filter(self, f):
         [2, 4]
         """
         def func(iterator):
-            return filter(fail_on_StopIteration(f), iterator)
+            return filter(fail_on_stopiteration(f), iterator)
         return self.mapPartitions(func, True)
 
     def distinct(self, numPartitions=None):
@@ -792,11 +792,11 @@ def foreach(self, f):
         >>> def f(x): print(x)
         >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
         """
-        safe_f = fail_on_StopIteration(f)
+        f = fail_on_stopiteration(f)
 
         def processPartition(iterator):
             for x in iterator:
-                safe_f(x)
+                f(x)
             return iter([])
         self.mapPartitions(processPartition).count()  # Force evaluation
 
@@ -843,7 +843,7 @@ def reduce(self, f):
             ...
         ValueError: Can not reduce() empty RDD
         """
-        safe_f = fail_on_StopIteration(f)
+        f = fail_on_stopiteration(f)
 
         def func(iterator):
             iterator = iter(iterator)
@@ -851,7 +851,7 @@ def func(iterator):
                 initial = next(iterator)
             except StopIteration:
                 return
-            yield reduce(safe_f, iterator, initial)
+            yield reduce(f, iterator, initial)
 
         vals = self.mapPartitions(func).collect()
         if vals:
@@ -916,12 +916,12 @@ def fold(self, zeroValue, op):
         >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
         15
         """
-        safe_op = fail_on_StopIteration(op)
+        op = fail_on_stopiteration(op)
 
         def func(iterator):
             acc = zeroValue
             for obj in iterator:
-                acc = safe_op(acc, obj)
+                acc = op(acc, obj)
             yield acc
         # collecting result of mapPartitions here ensures that the copy of
         # zeroValue provided to each partition is unique from the one provided
@@ -950,19 +950,19 @@ def aggregate(self, zeroValue, seqOp, combOp):
         >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp)
         (0, 0)
         """
-        safe_seqOp = fail_on_StopIteration(seqOp)
-        safe_combOp = fail_on_StopIteration(combOp)
+        seqOp = fail_on_stopiteration(seqOp)
+        combOp = fail_on_stopiteration(combOp)
 
         def func(iterator):
             acc = zeroValue
             for obj in iterator:
-                acc = safe_seqOp(acc, obj)
+                acc = seqOp(acc, obj)
             yield acc
         # collecting result of mapPartitions here ensures that the copy of
         # zeroValue provided to each partition is unique from the one provided
         # to the final reduce call
         vals = self.mapPartitions(func).collect()
-        return reduce(safe_combOp, vals, zeroValue)
+        return reduce(combOp, vals, zeroValue)
 
     def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
         """
@@ -1646,17 +1646,17 @@ def reduceByKeyLocally(self, func):
         >>> sorted(rdd.reduceByKeyLocally(add).items())
         [('a', 2), ('b', 1)]
         """
-        safe_func = fail_on_StopIteration(func)
+        func = fail_on_stopiteration(func)
 
         def reducePartition(iterator):
             m = {}
             for k, v in iterator:
-                m[k] = safe_func(m[k], v) if k in m else v
+                m[k] = func(m[k], v) if k in m else v
             yield m
 
         def mergeMaps(m1, m2):
             for k, v in m2.items():
-                m1[k] = safe_func(m1[k], v) if k in m1 else v
+                m1[k] = func(m1[k], v) if k in m1 else v
             return m1
         return self.mapPartitions(reducePartition).reduce(mergeMaps)
 
@@ -1858,7 +1858,6 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
         >>> sorted(x.combineByKey(to_list, append, extend).collect())
         [('a', [1, 2]), ('b', [1])]
         """
-
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()
 
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 250c3233c976f..bd0ac0039ffe1 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -28,7 +28,7 @@
 import pyspark.heapq3 as heapq
 from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
     CompressedSerializer, AutoBatchedSerializer
-from pyspark.util import fail_on_StopIteration
+from pyspark.util import fail_on_stopiteration
 
 
 try:
@@ -95,9 +95,9 @@ class Aggregator(object):
     """
 
     def __init__(self, createCombiner, mergeValue, mergeCombiners):
-        self.createCombiner = fail_on_StopIteration(createCombiner)
-        self.mergeValue = fail_on_StopIteration(mergeValue)
-        self.mergeCombiners = fail_on_StopIteration(mergeCombiners)
+        self.createCombiner = fail_on_stopiteration(createCombiner)
+        self.mergeValue = fail_on_stopiteration(mergeValue)
+        self.mergeCombiners = fail_on_stopiteration(mergeCombiners)
 
 
 class SimpleAggregator(Aggregator):
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 5807fde0812f8..e77a40e8b808f 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -89,15 +89,16 @@ def majorMinorVersion(sparkVersion):
                          " version numbers.")
 
 
-def fail_on_StopIteration(f):
-    """ wraps f to make it safe (= does not lead to data loss) to use inside a for loop
-    make StopIteration's raised inside f explicit
+def fail_on_stopiteration(f):
+    """
+    Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError'
+    prevents silent loss of data when 'f' is used in a for loop
     """
     def wrapper(*args, **kwargs):
         try:
             return f(*args, **kwargs)
         except StopIteration as exc:
-            raise RuntimeError('StopIteration in client code', exc)
+            raise RuntimeError("Caught StopIteration thrown from user's code; failing the task", exc)
 
     return wrapper
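
For reviewers, here is a standalone sketch (not part of the patch) of the data-loss scenario that fail_on_stopiteration guards against. The bad_key function and the local copy of the wrapper are illustrative only; no Spark installation is needed to run it. A StopIteration escaping a user-supplied function is swallowed by Python 3's lazy map() as end-of-data, silently truncating the output, whereas the wrapped function surfaces it as a RuntimeError.

# Illustration only: a local copy of the wrapper added in python/pyspark/util.py,
# so the behaviour can be shown outside of Spark.
def fail_on_stopiteration(f):
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except StopIteration as exc:
            raise RuntimeError(
                "Caught StopIteration thrown from user's code; failing the task", exc)
    return wrapper


def bad_key(x):
    # Hypothetical user function that (incorrectly) raises StopIteration.
    if x == 3:
        raise StopIteration()
    return x * 10


# Unwrapped: Python 3's lazy map() treats the stray StopIteration as exhaustion
# and the remaining elements are silently dropped.
print(list(map(bad_key, [1, 2, 3, 4, 5])))            # [10, 20]

# Wrapped: the same error now fails loudly instead of losing records.
try:
    list(map(fail_on_stopiteration(bad_key), [1, 2, 3, 4, 5]))
except RuntimeError as err:
    print("RuntimeError:", err.args[0])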