Merge pull request apache#311 from tmyklebu/master
SPARK-991: Report information gleaned from a Python stacktrace in the UI

Scala:

- Added setCallSite/clearCallSite to SparkContext and JavaSparkContext.
  These functions mutate a LocalProperty called "externalCallSite".
- Add a wrapper, getCallSite, that checks for an externalCallSite and, if
  none is found, calls the usual Utils.formatSparkCallSite.
- Change everything that calls Utils.formatSparkCallSite to call
  getCallSite instead, except getCallSite itself.
- Add setCallSite/clearCallSite wrappers to JavaSparkContext (see the
  sketch after this list).
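The intended calling pattern, from a frontend's point of view (a Python
sketch, since the Python frontend below is the consumer; the
run_with_call_site helper and the jsc/action parameters are illustrative,
only setCallSite/clearCallSite come from this change):

    def run_with_call_site(jsc, site, action):
        # jsc: a (Java)SparkContext handle, e.g. PySpark's sc._jsc py4j proxy
        # site: the string the UI should report for jobs launched by action()
        jsc.setCallSite(site)        # stored in the "externalCallSite" local property
        try:
            return action()          # jobs started here pick the site up via getCallSite
        finally:
            jsc.clearCallSite()      # afterwards, fall back to Utils.formatSparkCallSite

This is the lifecycle the Python-side wrapper described below automates.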

Python:

- Add a gruesome hack to rdd.py that inspects the traceback and guesses
  what you want to see in the UI.
- Add a RAII wrapper around said gruesome hack that calls
  setCallSite/clearCallSite as appropriate.
- Wire said RAII wrapper up around three calls into the Scala code
  (see the usage sketch after this list). I'm not sure that I hit all
  the spots with the RAII wrapper. I'm also not sure that my gruesome
  hack does exactly what we want.
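The wrapper is a context manager plus a module-level depth counter, so only
the outermost with-block sets the call site and only the outermost exit
clears it; nested wrapped calls leave it untouched. The collect() wiring in
the rdd.py diff below reduces to this pattern (a sketch that assumes it lives
inside rdd.py, where _JavaStackTrace and the RDD internals are defined):

    def collect(self):
        with _JavaStackTrace(self.context):   # depth 0 -> 1: setCallSite(guessed user frame)
            bytesInJava = self._jrdd.collect().iterator()
        # depth back to 0 here: setCallSite(None), so later jobs revert to the Scala default
        return list(self._collect_iterator_through_file(bytesInJava))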

One could also approach this change by refactoring
runJob/submitJob/runApproximateJob to take a call site, then threading
that parameter through everything that needs to know it.

One might object to the pointless-looking wrappers in JavaSparkContext.
Unfortunately, I can't directly access the SparkContext from
Python---or, if I can, I don't know how---so I need to wrap everything
that matters in JavaSparkContext.

Conflicts:
	core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
mateiz committed Jan 2, 2014
2 parents (3713f81 + fec0166), commit ca67909
Showing 4 changed files with 93 additions and 15 deletions.
26 changes: 23 additions & 3 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -728,6 +728,26 @@ class SparkContext(
     conf.getOption("spark.home").orElse(Option(System.getenv("SPARK_HOME")))
   }
 
+  /**
+   * Support function for API backtraces.
+   */
+  def setCallSite(site: String) {
+    setLocalProperty("externalCallSite", site)
+  }
+
+  /**
+   * Support function for API backtraces.
+   */
+  def clearCallSite() {
+    setLocalProperty("externalCallSite", null)
+  }
+
+  private[spark] def getCallSite(): String = {
+    val callSite = getLocalProperty("externalCallSite")
+    if (callSite == null) return Utils.formatSparkCallSite
+    callSite
+  }
+
   /**
    * Run a function on a given set of partitions in an RDD and pass the results to the given
    * handler function. This is the main entry point for all actions in Spark. The allowLocal
@@ -740,7 +760,7 @@ class SparkContext(
       partitions: Seq[Int],
       allowLocal: Boolean,
       resultHandler: (Int, U) => Unit) {
-    val callSite = Utils.formatSparkCallSite
+    val callSite = getCallSite
     val cleanedFunc = clean(func)
     logInfo("Starting job: " + callSite)
     val start = System.nanoTime
@@ -824,7 +844,7 @@ class SparkContext(
       func: (TaskContext, Iterator[T]) => U,
       evaluator: ApproximateEvaluator[U, R],
       timeout: Long): PartialResult[R] = {
-    val callSite = Utils.formatSparkCallSite
+    val callSite = getCallSite
     logInfo("Starting job: " + callSite)
     val start = System.nanoTime
     val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout,
@@ -844,7 +864,7 @@ class SparkContext(
       resultFunc: => R): SimpleFutureAction[R] =
   {
     val cleanF = clean(processPartition)
-    val callSite = Utils.formatSparkCallSite
+    val callSite = getCallSite
     val waiter = dagScheduler.submitJob(
       rdd,
       (context: TaskContext, iter: Iterator[T]) => cleanF(iter),
14 changes: 14 additions & 0 deletions core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -411,6 +411,20 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    * changed at runtime.
    */
   def getConf: SparkConf = sc.getConf
+
+  /**
+   * Pass-through to SparkContext.setCallSite. For API support only.
+   */
+  def setCallSite(site: String) {
+    sc.setCallSite(site)
+  }
+
+  /**
+   * Pass-through to SparkContext.setCallSite. For API support only.
+   */
+  def clearCallSite() {
+    sc.clearCallSite()
+  }
 }
 
 object JavaSparkContext {
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -953,7 +953,7 @@ abstract class RDD[T: ClassTag](
   private var storageLevel: StorageLevel = StorageLevel.NONE
 
   /** Record user function generating this RDD. */
-  @transient private[spark] val origin = Utils.formatSparkCallSite
+  @transient private[spark] val origin = sc.getCallSite
 
   private[spark] def elementClassTag: ClassTag[T] = classTag[T]
 
66 changes: 55 additions & 11 deletions python/pyspark/rdd.py
@@ -23,6 +23,7 @@
 import os
 import sys
 import shlex
+import traceback
 from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
@@ -39,6 +40,46 @@
 
 __all__ = ["RDD"]
 
+def _extract_concise_traceback():
+    tb = traceback.extract_stack()
+    if len(tb) == 0:
+        return "I'm lost!"
+    # HACK: This function is in a file called 'rdd.py' in the top level of
+    # everything PySpark. Just trim off the directory name and assume
+    # everything in that tree is PySpark guts.
+    file, line, module, what = tb[len(tb) - 1]
+    sparkpath = os.path.dirname(file)
+    first_spark_frame = len(tb) - 1
+    for i in range(0, len(tb)):
+        file, line, fun, what = tb[i]
+        if file.startswith(sparkpath):
+            first_spark_frame = i
+            break
+    if first_spark_frame == 0:
+        file, line, fun, what = tb[0]
+        return "%s at %s:%d" % (fun, file, line)
+    sfile, sline, sfun, swhat = tb[first_spark_frame]
+    ufile, uline, ufun, uwhat = tb[first_spark_frame-1]
+    return "%s at %s:%d" % (sfun, ufile, uline)
+
+_spark_stack_depth = 0
+
+class _JavaStackTrace(object):
+    def __init__(self, sc):
+        self._traceback = _extract_concise_traceback()
+        self._context = sc
+
+    def __enter__(self):
+        global _spark_stack_depth
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(self._traceback)
+        _spark_stack_depth += 1
+
+    def __exit__(self, type, value, tb):
+        global _spark_stack_depth
+        _spark_stack_depth -= 1
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(None)
+
 class RDD(object):
     """
@@ -401,7 +442,8 @@ def collect(self):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        bytesInJava = self._jrdd.collect().iterator()
+        with _JavaStackTrace(self.context) as st:
+            bytesInJava = self._jrdd.collect().iterator()
         return list(self._collect_iterator_through_file(bytesInJava))
 
     def _collect_iterator_through_file(self, iterator):
@@ -582,13 +624,14 @@ def takeUpToNum(iterator):
         # TODO(shivaram): Similar to the scala implementation, update the take
         # method to scan multiple splits based on an estimate of how many elements
         # we have per-split.
-        for partition in range(mapped._jrdd.splits().size()):
-            partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
-            partitionsToTake[0] = partition
-            iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
-            items.extend(mapped._collect_iterator_through_file(iterator))
-            if len(items) >= num:
-                break
+        with _JavaStackTrace(self.context) as st:
+            for partition in range(mapped._jrdd.splits().size()):
+                partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+                partitionsToTake[0] = partition
+                iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
+                items.extend(mapped._collect_iterator_through_file(iterator))
+                if len(items) >= num:
+                    break
         return items[:num]
 
     def first(self):
@@ -765,9 +808,10 @@ def add_shuffle_key(split, iterator):
                 yield outputSerializer.dumps(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
-        pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
-                                                      id(partitionFunc))
+        with _JavaStackTrace(self.context) as st:
+            pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+            partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+                                                          id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
         rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
         # This is required so that id(partitionFunc) remains unique, even if
