diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index ccbca67656c8d..549fcb8e9559a 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -215,6 +215,21 @@ def addInPlace(self, value1, value2):
 COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
 
 
+class StatsParam(AccumulatorParam):
+    """StatsParam is used to merge pstats.Stats"""
+
+    @staticmethod
+    def zero(value):
+        return None
+
+    @staticmethod
+    def addInPlace(value1, value2):
+        if value1 is None:
+            return value2
+        value1.add(value2)
+        return value1
+
+
 class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
 
     """
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 5667154cb84a8..9dcccdd37a20a 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -31,9 +31,11 @@
 import warnings
 import heapq
 import bisect
+import atexit
 from random import Random
 from math import sqrt, log, isinf, isnan
 
+from pyspark.accumulators import StatsParam
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
     PickleSerializer, pack_long, CompressedSerializer
@@ -2076,6 +2078,7 @@ class PipelinedRDD(RDD):
     >>> rdd.flatMap(lambda x: [x, x]).reduce(add)
     20
     """
+    _created_profiles = []
 
     def __init__(self, prev, func, preservesPartitioning=False):
         if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
@@ -2110,7 +2113,9 @@ def _jrdd(self):
             return self._jrdd_val
         if self._bypass_serializer:
             self._jrdd_deserializer = NoOpSerializer()
-        command = (self.func, self._prev_jrdd_deserializer,
+        enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
+        profileStats = self.ctx.accumulator(None, StatsParam) if enable_profile else None
+        command = (self.func, profileStats, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
         ser = CloudPickleSerializer()
         pickled_command = ser.dumps(command)
@@ -2128,8 +2133,39 @@ def _jrdd(self):
                                              self.ctx.pythonExec,
                                              broadcast_vars, self.ctx._javaAccumulator)
         self._jrdd_val = python_rdd.asJavaRDD()
+
+        if enable_profile:
+            self._id = self._jrdd_val.id()
+            if not self._created_profiles:
+                dump_path = self.ctx._conf.get("spark.python.profile.dump")
+                if dump_path:
+                    atexit.register(PipelinedRDD.dump_profile, dump_path)
+                else:
+                    atexit.register(PipelinedRDD.show_profile)
+            self._created_profiles.append((self._id, profileStats))
+
         return self._jrdd_val
 
+    @classmethod
+    def show_profile(cls):
+        for id, acc in cls._created_profiles:
+            stats = acc.value
+            if stats:
+                print "="*60
+                print "Profile of RDD<id:%d>" % id
+                print "="*60
+                stats.sort_stats("tottime", "cumtime").print_stats()
+
+    @classmethod
+    def dump_profile(cls, dump_path):
+        if not os.path.exists(dump_path):
+            os.makedirs(dump_path)
+        for id, acc in cls._created_profiles:
+            stats = acc.value
+            if stats:
+                path = os.path.join(dump_path, "rdd_%d.pstats" % id)
+                stats.dump_stats(path)
+
     def id(self):
         if self._id is None:
             self._id = self._jrdd.id()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 6805063e06798..6e5c1e17d2647 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,6 +23,9 @@
 import time
 import socket
 import traceback
+import cProfile
+import pstats
+
 # CloudPickler needs to be imported so that depicklers are registered using the
 # copy_reg module.
 from pyspark.accumulators import _accumulatorRegistry
@@ -73,10 +76,21 @@ def main(infile, outfile):
             _broadcastRegistry[bid] = Broadcast(bid, value)
 
         command = pickleSer._read_with_length(infile)
-        (func, deserializer, serializer) = command
+        (func, stats, deserializer, serializer) = command
         init_time = time.time()
-        iterator = deserializer.load_stream(infile)
-        serializer.dump_stream(func(split_index, iterator), outfile)
+
+        def process():
+            iterator = deserializer.load_stream(infile)
+            serializer.dump_stream(func(split_index, iterator), outfile)
+
+        if stats:
+            p = cProfile.Profile()
+            p.runcall(process)
+            st = pstats.Stats(p)
+            st.stream = None  # make it picklable
+            stats.add(st.strip_dirs())
+        else:
+            process()
     except Exception:
         try:
             write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
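
For reference, a minimal driver-side sketch of how the profiling added by this patch would be enabled, based only on the config keys that appear in the diff ("spark.python.profile" and "spark.python.profile.dump"); the dump directory and the sample job are hypothetical, not part of the patch:

    from pyspark import SparkConf, SparkContext

    # Hypothetical driver script; the dump directory and job below are examples only.
    conf = SparkConf()
    conf.set("spark.python.profile", "true")                # run cProfile around each task in the workers
    conf.set("spark.python.profile.dump", "/tmp/profiles")  # dump pstats files at exit instead of printing
    sc = SparkContext(conf=conf)

    # Each pipelined Python RDD accumulates merged pstats.Stats; the atexit hook
    # registered in PipelinedRDD writes /tmp/profiles/rdd_<id>.pstats at shutdown
    # (or prints a per-RDD report when no dump path is configured).
    sc.parallelize(range(1000)).map(lambda x: x * x).count()
    sc.stop()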