diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 9b34a6578f0db..90904d3741521 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -655,12 +655,14 @@ def foreachPartition(self, f): >>> def f(iterator): ... for x in iterator: ... print x - ... yield None >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f) """ def func(it): - f(it) - return iter([]) + r = f(it) + try: + return iter(r) + except TypeError: + return iter([]) self.mapPartitions(func).count() # Force evaluation def collect(self):