diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 84bd0f7ffdf64..4d6a97e255d34 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -728,6 +728,26 @@ class SparkContext(
     conf.getOption("spark.home").orElse(Option(System.getenv("SPARK_HOME")))
   }
 
+  /**
+   * Support function for API backtraces.
+   */
+  def setCallSite(site: String) {
+    setLocalProperty("externalCallSite", site)
+  }
+
+  /**
+   * Support function for API backtraces.
+   */
+  def clearCallSite() {
+    setLocalProperty("externalCallSite", null)
+  }
+
+  private[spark] def getCallSite(): String = {
+    val callSite = getLocalProperty("externalCallSite")
+    if (callSite == null) return Utils.formatSparkCallSite
+    callSite
+  }
+
   /**
    * Run a function on a given set of partitions in an RDD and pass the results to the given
    * handler function. This is the main entry point for all actions in Spark. The allowLocal
@@ -740,7 +760,7 @@ class SparkContext(
       partitions: Seq[Int],
       allowLocal: Boolean,
       resultHandler: (Int, U) => Unit) {
-    val callSite = Utils.formatSparkCallSite
+    val callSite = getCallSite
     val cleanedFunc = clean(func)
     logInfo("Starting job: " + callSite)
     val start = System.nanoTime
@@ -824,7 +844,7 @@ class SparkContext(
      func: (TaskContext, Iterator[T]) => U,
      evaluator: ApproximateEvaluator[U, R],
      timeout: Long): PartialResult[R] = {
-    val callSite = Utils.formatSparkCallSite
+    val callSite = getCallSite
     logInfo("Starting job: " + callSite)
     val start = System.nanoTime
     val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout,
@@ -844,7 +864,7 @@ class SparkContext(
      resultFunc: => R): SimpleFutureAction[R] =
   {
     val cleanF = clean(processPartition)
-    val callSite = Utils.formatSparkCallSite
+    val callSite = getCallSite
     val waiter = dagScheduler.submitJob(
       rdd,
       (context: TaskContext, iter: Iterator[T]) => cleanF(iter),
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 0680a065a3082..5be5317f40e7e 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -411,6 +411,20 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    * changed at runtime.
    */
   def getConf: SparkConf = sc.getConf
+
+  /**
+   * Pass-through to SparkContext.setCallSite. For API support only.
+   */
+  def setCallSite(site: String) {
+    sc.setCallSite(site)
+  }
+
+  /**
+   * Pass-through to SparkContext.setCallSite. For API support only.
+   */
+  def clearCallSite() {
+    sc.clearCallSite()
+  }
 }
 
 object JavaSparkContext {
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 6a7b0f8a86b6d..3f41b66279987 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -953,7 +953,7 @@ abstract class RDD[T: ClassTag](
   private var storageLevel: StorageLevel = StorageLevel.NONE
 
   /** Record user function generating this RDD. */
-  @transient private[spark] val origin = Utils.formatSparkCallSite
+  @transient private[spark] val origin = sc.getCallSite
 
   private[spark] def elementClassTag: ClassTag[T] = classTag[T]
 
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index f87923e6fa4eb..6fb4a7b3be25d 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -23,6 +23,7 @@
 import os
 import sys
 import shlex
+import traceback
 from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
@@ -39,6 +40,46 @@
 __all__ = ["RDD"]
 
 
+def _extract_concise_traceback():
+    tb = traceback.extract_stack()
+    if len(tb) == 0:
+        return "I'm lost!"
+    # HACK: This function is in a file called 'rdd.py' in the top level of
+    # everything PySpark. Just trim off the directory name and assume
+    # everything in that tree is PySpark guts.
+    file, line, module, what = tb[len(tb) - 1]
+    sparkpath = os.path.dirname(file)
+    first_spark_frame = len(tb) - 1
+    for i in range(0, len(tb)):
+        file, line, fun, what = tb[i]
+        if file.startswith(sparkpath):
+            first_spark_frame = i
+            break
+    if first_spark_frame == 0:
+        file, line, fun, what = tb[0]
+        return "%s at %s:%d" % (fun, file, line)
+    sfile, sline, sfun, swhat = tb[first_spark_frame]
+    ufile, uline, ufun, uwhat = tb[first_spark_frame-1]
+    return "%s at %s:%d" % (sfun, ufile, uline)
+
+_spark_stack_depth = 0
+
+class _JavaStackTrace(object):
+    def __init__(self, sc):
+        self._traceback = _extract_concise_traceback()
+        self._context = sc
+
+    def __enter__(self):
+        global _spark_stack_depth
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(self._traceback)
+        _spark_stack_depth += 1
+
+    def __exit__(self, type, value, tb):
+        global _spark_stack_depth
+        _spark_stack_depth -= 1
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(None)
 
 class RDD(object):
     """
@@ -401,7 +442,8 @@ def collect(self):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        bytesInJava = self._jrdd.collect().iterator()
+        with _JavaStackTrace(self.context) as st:
+            bytesInJava = self._jrdd.collect().iterator()
         return list(self._collect_iterator_through_file(bytesInJava))
 
     def _collect_iterator_through_file(self, iterator):
@@ -582,13 +624,14 @@ def takeUpToNum(iterator):
         # TODO(shivaram): Similar to the scala implementation, update the take
         # method to scan multiple splits based on an estimate of how many elements
         # we have per-split.
-        for partition in range(mapped._jrdd.splits().size()):
-            partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
-            partitionsToTake[0] = partition
-            iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
-            items.extend(mapped._collect_iterator_through_file(iterator))
-            if len(items) >= num:
-                break
+        with _JavaStackTrace(self.context) as st:
+            for partition in range(mapped._jrdd.splits().size()):
+                partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+                partitionsToTake[0] = partition
+                iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
+                items.extend(mapped._collect_iterator_through_file(iterator))
+                if len(items) >= num:
+                    break
         return items[:num]
 
     def first(self):
@@ -765,9 +808,10 @@ def add_shuffle_key(split, iterator):
                 yield outputSerializer.dumps(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
-        pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
-                                                      id(partitionFunc))
+        with _JavaStackTrace(self.context) as st:
+            pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+            partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+                                                          id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
         rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
         # This is required so that id(partitionFunc) remains unique, even if
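
Usage note (not part of the patch): the sketch below shows how a wrapper layer built on the Scala/Java API could drive the new call-site hooks, mirroring what PySpark's _JavaStackTrace context manager does over Py4J. The object name, local master setting, and the call-site string are illustrative assumptions only.

import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical example object; nothing below ships with this change.
object CallSiteSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("callsite-sketch"))

    // A wrapper would compute its own user-facing call-site string
    // (PySpark derives one from the Python traceback in _extract_concise_traceback).
    val externalSite = "collect at user_script.py:42"

    sc.setCallSite(externalSite)
    try {
      // Jobs started while the "externalCallSite" property is set are logged and
      // reported against the external call site rather than the JVM-side frame.
      sc.parallelize(1 to 100).map(_ * 2).collect()
    } finally {
      // Clearing falls back to the default Utils.formatSparkCallSite behavior.
      sc.clearCallSite()
    }

    sc.stop()
  }
}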