[SPARK-3478] [PySpark] Profile the Python tasks
This patch adds profiling support for PySpark. It will show the profiling results
before the driver exits. Here is one example:

```
============================================================
Profile of RDD<id=3>
============================================================
         5146507 function calls (5146487 primitive calls) in 71.094 seconds

   Ordered by: internal time, cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  5144576   68.331    0.000   68.331    0.000 statcounter.py:44(merge)
       20    2.735    0.137   71.071    3.554 statcounter.py:33(__init__)
       20    0.017    0.001    0.017    0.001 {cPickle.dumps}
     1024    0.003    0.000    0.003    0.000 t.py:16(<lambda>)
       20    0.001    0.000    0.001    0.000 {reduce}
       21    0.001    0.000    0.001    0.000 {cPickle.loads}
       20    0.001    0.000    0.001    0.000 copy_reg.py:95(_slotnames)
       41    0.001    0.000    0.001    0.000 serializers.py:461(read_int)
       40    0.001    0.000    0.002    0.000 serializers.py:179(_batched)
       62    0.000    0.000    0.000    0.000 {method 'read' of 'file' objects}
       20    0.000    0.000   71.072    3.554 rdd.py:863(<lambda>)
       20    0.000    0.000    0.001    0.000 serializers.py:198(load_stream)
    40/20    0.000    0.000   71.072    3.554 rdd.py:2093(pipeline_func)
       41    0.000    0.000    0.002    0.000 serializers.py:130(load_stream)
       40    0.000    0.000   71.072    1.777 rdd.py:304(func)
       20    0.000    0.000   71.094    3.555 worker.py:82(process)
```

Users can also show the profile results manually with `sc.show_profiles()` or dump them to disk
with `sc.dump_profiles(path)`, for example:

```python
>>> sc._conf.set("spark.python.profile", "true")
>>> rdd = sc.parallelize(range(100)).map(str)
>>> rdd.count()
100
>>> sc.show_profiles()
============================================================
Profile of RDD<id=1>
============================================================
         284 function calls (276 primitive calls) in 0.001 seconds

   Ordered by: internal time, cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        4    0.000    0.000    0.000    0.000 serializers.py:198(load_stream)
        4    0.000    0.000    0.000    0.000 {reduce}
     12/4    0.000    0.000    0.001    0.000 rdd.py:2092(pipeline_func)
        4    0.000    0.000    0.000    0.000 {cPickle.loads}
        4    0.000    0.000    0.000    0.000 {cPickle.dumps}
      104    0.000    0.000    0.000    0.000 rdd.py:852(<genexpr>)
        8    0.000    0.000    0.000    0.000 serializers.py:461(read_int)
       12    0.000    0.000    0.000    0.000 rdd.py:303(func)
```
Profiling is disabled by default and can be enabled by setting "spark.python.profile=true".
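
For instance, profiling can be turned on through a `SparkConf` when the `SparkContext` is created, which is how the new `TestProfiler` test enables it; a minimal sketch (the master URL and app name below are illustrative):

```python
from pyspark import SparkConf, SparkContext

# Turn on the Python worker profiler; master URL and app name are illustrative.
conf = SparkConf().set("spark.python.profile", "true")
sc = SparkContext("local[4]", "profile-demo", conf=conf)

sc.parallelize(range(100)).map(str).count()
sc.show_profiles()  # print the collected stats for each profiled RDD
```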

Also, users can have the results dumped to disk automatically for later analysis by setting "spark.python.profile.dump=path_to_dump"
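
The dumped files (one `rdd_<id>.pstats` file per RDD) can later be loaded with the standard `pstats` module; a minimal sketch, where the dump directory and RDD id are assumed:

```python
import pstats

# Load one per-RDD dump written by sc.dump_profiles(path); the path and id are assumed.
stats = pstats.Stats("/tmp/spark_profile/rdd_1.pstats")
stats.sort_stats("time", "cumulative").print_stats(10)  # show the 10 most expensive entries
```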

This is a bugfix of apache#2351. cc JoshRosen

Author: Davies Liu <[email protected]>

Closes apache#2556 from davies/profiler and squashes the following commits:

e68df5a [Davies Liu] Merge branch 'master' of github.com:apache/spark into profiler
858e74c [Davies Liu] compatitable with python 2.6
7ef2aa0 [Davies Liu] bugfix, add tests for show_profiles and dump_profiles()
2b0daf2 [Davies Liu] fix docs
7a56c24 [Davies Liu] bugfix
cba9463 [Davies Liu] move show_profiles and dump_profiles to SparkContext
fb9565b [Davies Liu] Merge branch 'master' of github.com:apache/spark into profiler
116d52a [Davies Liu] Merge branch 'master' of github.com:apache/spark into profiler
09d02c3 [Davies Liu] Merge branch 'master' into profiler
c23865c [Davies Liu] Merge branch 'master' into profiler
15d6f18 [Davies Liu] add docs for two configs
dadee1a [Davies Liu] add docs string and clear profiles after show or dump
4f8309d [Davies Liu] address comment, add tests
0a5b6eb [Davies Liu] fix Python UDF
4b20494 [Davies Liu] add profile for python
davies authored and JoshRosen committed Oct 1, 2014
1 parent d75496b commit c5414b6
Showing 7 changed files with 127 additions and 7 deletions.
19 changes: 19 additions & 0 deletions docs/configuration.md
@@ -206,6 +206,25 @@ Apart from these, the following properties are also available, and may be useful
used during aggregation goes above this amount, it will spill the data into disks.
</td>
</tr>
<tr>
<td><code>spark.python.profile</code></td>
<td>false</td>
<td>
Enable profiling in the Python worker. The profile results will show up via `sc.show_profiles()`,
or they will be displayed before the driver exits. They can also be dumped to disk with
`sc.dump_profiles(path)`. If some of the profile results have been displayed manually,
they will not be displayed automatically before the driver exits.
</td>
</tr>
<tr>
<td><code>spark.python.profile.dump</code></td>
<td>(none)</td>
<td>
The directory used to dump the profile results before the driver exits.
The results will be dumped as a separate file for each RDD. They can be loaded
with `pstats.Stats()`. If this is specified, the profile results will not be displayed
automatically.
</td>
</tr>
<tr>
<td><code>spark.python.worker.reuse</code></td>
<td>true</td>
15 changes: 15 additions & 0 deletions python/pyspark/accumulators.py
@@ -215,6 +215,21 @@ def addInPlace(self, value1, value2):
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)


class PStatsParam(AccumulatorParam):
"""PStatsParam is used to merge pstats.Stats"""

@staticmethod
def zero(value):
return None

@staticmethod
def addInPlace(value1, value2):
if value1 is None:
return value2
value1.add(value2)
return value1


class _UpdateRequestHandler(SocketServer.StreamRequestHandler):

"""
39 changes: 38 additions & 1 deletion python/pyspark/context.py
@@ -20,6 +20,7 @@
import sys
from threading import Lock
from tempfile import NamedTemporaryFile
import atexit

from pyspark import accumulators
from pyspark.accumulators import Accumulator
@@ -30,7 +31,6 @@
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
PairDeserializer, CompressedSerializer
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call

@@ -192,6 +192,9 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
self._temp_dir = \
self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()

# profiling stats collected for each PythonRDD
self._profile_stats = []

def _initialize_context(self, jconf):
"""
Initialize SparkContext in function to allow subclass specific initialization
@@ -792,6 +795,40 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
return list(mappedRDD._collect_iterator_through_file(it))

def _add_profile(self, id, profileAcc):
if not self._profile_stats:
dump_path = self._conf.get("spark.python.profile.dump")
if dump_path:
atexit.register(self.dump_profiles, dump_path)
else:
atexit.register(self.show_profiles)

self._profile_stats.append([id, profileAcc, False])

def show_profiles(self):
""" Print the profile stats to stdout """
for i, (id, acc, showed) in enumerate(self._profile_stats):
stats = acc.value
if not showed and stats:
print "=" * 60
print "Profile of RDD<id=%d>" % id
print "=" * 60
stats.sort_stats("time", "cumulative").print_stats()
# mark it as showed
self._profile_stats[i][2] = True

def dump_profiles(self, path):
""" Dump the profile stats into directory `path`
"""
if not os.path.exists(path):
os.makedirs(path)
for id, acc, _ in self._profile_stats:
stats = acc.value
if stats:
p = os.path.join(path, "rdd_%d.pstats" % id)
stats.dump_stats(p)
self._profile_stats = []


def _test():
import atexit
10 changes: 8 additions & 2 deletions python/pyspark/rdd.py
@@ -15,7 +15,6 @@
# limitations under the License.
#

from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
from itertools import chain, ifilter, imap
@@ -32,6 +31,7 @@
from random import Random
from math import sqrt, log, isinf, isnan

from pyspark.accumulators import PStatsParam
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
PickleSerializer, pack_long, AutoBatchedSerializer
@@ -2080,7 +2080,9 @@ def _jrdd(self):
return self._jrdd_val
if self._bypass_serializer:
self._jrdd_deserializer = NoOpSerializer()
command = (self.func, self._prev_jrdd_deserializer,
enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
command = (self.func, profileStats, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
# the serialized command will be compressed by broadcast
ser = CloudPickleSerializer()
@@ -2102,6 +2104,10 @@ def _jrdd(self):
self.ctx.pythonExec,
broadcast_vars, self.ctx._javaAccumulator)
self._jrdd_val = python_rdd.asJavaRDD()

if enable_profile:
self._id = self._jrdd_val.id()
self.ctx._add_profile(self._id, profileStats)
return self._jrdd_val

def id(self):
2 changes: 1 addition & 1 deletion python/pyspark/sql.py
@@ -960,7 +960,7 @@ def registerFunction(self, name, f, returnType=StringType()):
[Row(c0=4)]
"""
func = lambda _, it: imap(lambda x: f(*x), it)
command = (func,
command = (func, None,
BatchedSerializer(PickleSerializer(), 1024),
BatchedSerializer(PickleSerializer(), 1024))
ser = CloudPickleSerializer()
30 changes: 30 additions & 0 deletions python/pyspark/tests.py
@@ -632,6 +632,36 @@ def test_distinct(self):
self.assertEquals(result.count(), 3)


class TestProfiler(PySparkTestCase):

def setUp(self):
self._old_sys_path = list(sys.path)
class_name = self.__class__.__name__
conf = SparkConf().set("spark.python.profile", "true")
self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf)

def test_profiler(self):

def heavy_foo(x):
for i in range(1 << 20):
x = 1
rdd = self.sc.parallelize(range(100))
rdd.foreach(heavy_foo)
profiles = self.sc._profile_stats
self.assertEqual(1, len(profiles))
id, acc, _ = profiles[0]
stats = acc.value
self.assertTrue(stats is not None)
width, stat_list = stats.get_print_list([])
func_names = [func_name for fname, n, func_name in stat_list]
self.assertTrue("heavy_foo" in func_names)

self.sc.show_profiles()
d = tempfile.gettempdir()
self.sc.dump_profiles(d)
self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))


class TestSQL(PySparkTestCase):

def setUp(self):
19 changes: 16 additions & 3 deletions python/pyspark/worker.py
@@ -23,6 +23,8 @@
import time
import socket
import traceback
import cProfile
import pstats

from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
@@ -90,10 +92,21 @@ def main(infile, outfile):
command = pickleSer._read_with_length(infile)
if isinstance(command, Broadcast):
command = pickleSer.loads(command.value)
(func, deserializer, serializer) = command
(func, stats, deserializer, serializer) = command
init_time = time.time()
iterator = deserializer.load_stream(infile)
serializer.dump_stream(func(split_index, iterator), outfile)

def process():
iterator = deserializer.load_stream(infile)
serializer.dump_stream(func(split_index, iterator), outfile)

if stats:
p = cProfile.Profile()
p.runcall(process)
st = pstats.Stats(p)
st.stream = None # make it picklable
stats.add(st.strip_dirs())
else:
process()
except Exception:
try:
write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
