
Commit

improved doc, error message and code style
e-dorigatti committed May 24, 2018
1 parent ee54924 commit d739eea
Showing 3 changed files with 26 additions and 26 deletions.
35 changes: 17 additions & 18 deletions python/pyspark/rdd.py
@@ -51,7 +51,7 @@
from pyspark.shuffle import Aggregator, ExternalMerger, \
get_used_memory, ExternalSorter, ExternalGroupBy
from pyspark.traceback_utils import SCCallSiteSync
-from pyspark.util import fail_on_StopIteration
+from pyspark.util import fail_on_stopiteration


__all__ = ["RDD"]
@@ -333,7 +333,7 @@ def map(self, f, preservesPartitioning=False):
[('a', 1), ('b', 1), ('c', 1)]
"""
def func(_, iterator):
-return map(fail_on_StopIteration(f), iterator)
+return map(fail_on_stopiteration(f), iterator)
return self.mapPartitionsWithIndex(func, preservesPartitioning)

def flatMap(self, f, preservesPartitioning=False):
@@ -348,7 +348,7 @@ def flatMap(self, f, preservesPartitioning=False):
[(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
"""
def func(s, iterator):
-return chain.from_iterable(map(fail_on_StopIteration(f), iterator))
+return chain.from_iterable(map(fail_on_stopiteration(f), iterator))
return self.mapPartitionsWithIndex(func, preservesPartitioning)

def mapPartitions(self, f, preservesPartitioning=False):
@@ -411,7 +411,7 @@ def filter(self, f):
[2, 4]
"""
def func(iterator):
-return filter(fail_on_StopIteration(f), iterator)
+return filter(fail_on_stopiteration(f), iterator)
return self.mapPartitions(func, True)

def distinct(self, numPartitions=None):
@@ -792,11 +792,11 @@ def foreach(self, f):
>>> def f(x): print(x)
>>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
"""
-safe_f = fail_on_StopIteration(f)
+f = fail_on_stopiteration(f)

def processPartition(iterator):
for x in iterator:
-safe_f(x)
+f(x)
return iter([])
self.mapPartitions(processPartition).count() # Force evaluation

@@ -843,15 +843,15 @@ def reduce(self, f):
...
ValueError: Can not reduce() empty RDD
"""
-safe_f = fail_on_StopIteration(f)
+f = fail_on_stopiteration(f)

def func(iterator):
iterator = iter(iterator)
try:
initial = next(iterator)
except StopIteration:
return
-yield reduce(safe_f, iterator, initial)
+yield reduce(f, iterator, initial)

vals = self.mapPartitions(func).collect()
if vals:
@@ -916,12 +916,12 @@ def fold(self, zeroValue, op):
>>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
15
"""
-safe_op = fail_on_StopIteration(op)
+op = fail_on_stopiteration(op)

def func(iterator):
acc = zeroValue
for obj in iterator:
-acc = safe_op(acc, obj)
+acc = op(acc, obj)
yield acc
# collecting result of mapPartitions here ensures that the copy of
# zeroValue provided to each partition is unique from the one provided
@@ -950,19 +950,19 @@ def aggregate(self, zeroValue, seqOp, combOp):
>>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp)
(0, 0)
"""
-safe_seqOp = fail_on_StopIteration(seqOp)
-safe_combOp = fail_on_StopIteration(combOp)
+seqOp = fail_on_stopiteration(seqOp)
+combOp = fail_on_stopiteration(combOp)

def func(iterator):
acc = zeroValue
for obj in iterator:
-acc = safe_seqOp(acc, obj)
+acc = seqOp(acc, obj)
yield acc
# collecting result of mapPartitions here ensures that the copy of
# zeroValue provided to each partition is unique from the one provided
# to the final reduce call
vals = self.mapPartitions(func).collect()
-return reduce(safe_combOp, vals, zeroValue)
+return reduce(combOp, vals, zeroValue)

def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
"""
@@ -1646,17 +1646,17 @@ def reduceByKeyLocally(self, func):
>>> sorted(rdd.reduceByKeyLocally(add).items())
[('a', 2), ('b', 1)]
"""
-safe_func = fail_on_StopIteration(func)
+func = fail_on_stopiteration(func)

def reducePartition(iterator):
m = {}
for k, v in iterator:
-m[k] = safe_func(m[k], v) if k in m else v
+m[k] = func(m[k], v) if k in m else v
yield m

def mergeMaps(m1, m2):
for k, v in m2.items():
-m1[k] = safe_func(m1[k], v) if k in m1 else v
+m1[k] = func(m1[k], v) if k in m1 else v
return m1
return self.mapPartitions(reducePartition).reduce(mergeMaps)

@@ -1858,7 +1858,6 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
>>> sorted(x.combineByKey(to_list, append, extend).collect())
[('a', [1, 2]), ('b', [1])]
"""

if numPartitions is None:
numPartitions = self._defaultReducePartitions()

8 changes: 4 additions & 4 deletions python/pyspark/shuffle.py
@@ -28,7 +28,7 @@
import pyspark.heapq3 as heapq
from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \
CompressedSerializer, AutoBatchedSerializer
-from pyspark.util import fail_on_StopIteration
+from pyspark.util import fail_on_stopiteration


try:
@@ -95,9 +95,9 @@ class Aggregator(object):
"""

def __init__(self, createCombiner, mergeValue, mergeCombiners):
-self.createCombiner = fail_on_StopIteration(createCombiner)
-self.mergeValue = fail_on_StopIteration(mergeValue)
-self.mergeCombiners = fail_on_StopIteration(mergeCombiners)
+self.createCombiner = fail_on_stopiteration(createCombiner)
+self.mergeValue = fail_on_stopiteration(mergeValue)
+self.mergeCombiners = fail_on_stopiteration(mergeCombiners)


class SimpleAggregator(Aggregator):
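For context on the shuffle.py change: the three combiner callbacks are wrapped once in the constructor, so every later call site in the shuffle machinery reaches the StopIteration-safe versions through self.createCombiner, self.mergeValue and self.mergeCombiners without further changes. A hedged usage sketch, assuming a PySpark checkout containing this commit is on PYTHONPATH (the lambdas are illustrative, not part of the codebase):

from pyspark.shuffle import Aggregator

agg = Aggregator(
    createCombiner=lambda v: [v],           # first value seen for a key
    mergeValue=lambda c, v: c + [v],        # fold another value into a combiner
    mergeCombiners=lambda c1, c2: c1 + c2,  # merge combiners across partitions
)

combiner = agg.createCombiner(1)
combiner = agg.mergeValue(combiner, 2)
print(agg.mergeCombiners(combiner, [3]))    # [1, 2, 3]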
9 changes: 5 additions & 4 deletions python/pyspark/util.py
@@ -89,15 +89,16 @@ def majorMinorVersion(sparkVersion):
" version numbers.")


-def fail_on_StopIteration(f):
-""" wraps f to make it safe (= does not lead to data loss) to use inside a for loop
-make StopIteration's raised inside f explicit
+def fail_on_stopiteration(f):
+"""
+Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError'
+prevents silent loss of data when 'f' is used in a for loop
"""
def wrapper(*args, **kwargs):
try:
return f(*args, **kwargs)
except StopIteration as exc:
-raise RuntimeError('StopIteration in client code', exc)
+raise RuntimeError("Caught StopIteration thrown from user's code; failing the task", exc)

return wrapper
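For context on why the wrapper exists at all, here is a small, self-contained demonstration of the failure mode it guards against. The names make_pick, pick and rows are illustrative, and the example assumes a pyspark build containing this change is importable.

from pyspark.util import fail_on_stopiteration

def make_pick():
    keys = iter(["a", "b"])
    def pick(row):
        return row[next(keys)]  # StopIteration escapes after two calls
    return pick

rows = [{"a": 1, "b": 2}] * 4

# Without the wrapper, map() treats the escaping StopIteration as end-of-input:
# the result is silently truncated to [1, 2] and no error is reported.
print(list(map(make_pick(), rows)))

# With the wrapper, the same bug fails loudly instead of losing data.
try:
    list(map(fail_on_stopiteration(make_pick()), rows))
except RuntimeError as err:
    print("failed loudly:", err)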

