diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 4a1cd0dba3116..587d5c7dd0fee 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -77,30 +77,42 @@ private[spark] case class PythonFunction(
     broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]])
 
+/**
+ * A wrapper for chained Python functions (from bottom to top).
+ * @param funcs the chained functions, ordered from bottom (innermost, applied first) to top
+ */
+private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
 
-object PythonRunner {
+private[spark] object PythonRunner {
   def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
-    new PythonRunner(Seq(Seq(func)), bufferSize, reuse_worker, false, Seq(1))
+    new PythonRunner(
+      Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Seq(Seq(0)))
   }
 }
 
 /**
  * A helper class to run Python mapPartition/UDFs in Spark.
+ *
+ * funcs is a list of independent Python functions, each of which is a list of chained Python
+ * functions (from bottom to top).
  */
 private[spark] class PythonRunner(
-    funcs: Seq[Seq[PythonFunction]],
+    funcs: Seq[ChainedPythonFunctions],
     bufferSize: Int,
     reuse_worker: Boolean,
     isUDF: Boolean,
-    numArgs: Seq[Int])
+    argOffsets: Seq[Seq[Int]])
   extends Logging {
 
+  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
+
   // All the Python functions should have the same exec, version and envvars.
-  private val envVars = funcs.head.head.envVars
-  private val pythonExec = funcs.head.head.pythonExec
-  private val pythonVer = funcs.head.head.pythonVer
+  private val envVars = funcs.head.funcs.head.envVars
+  private val pythonExec = funcs.head.funcs.head.pythonExec
+  private val pythonVer = funcs.head.funcs.head.pythonVer
 
-  private val accumulator = funcs.head.head.accumulator // TODO: support accumulator in multiple UDF
+  // TODO: support accumulator in multiple UDFs
+  private val accumulator = funcs.head.funcs.head.accumulator
 
   def compute(
       inputIterator: Iterator[_],
@@ -240,8 +252,8 @@ private[spark] class PythonRunner(
 
     @volatile private var _exception: Exception = null
 
-    private val pythonIncludes = funcs.flatMap(_.flatMap(_.pythonIncludes.asScala)).toSet
-    private val broadcastVars = funcs.flatMap(_.flatMap(_.broadcastVars.asScala))
+    private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+    private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
 
     setDaemon(true)
 
@@ -295,17 +307,20 @@ private[spark] class PythonRunner(
         if (isUDF) {
           dataOut.writeInt(1)
           dataOut.writeInt(funcs.length)
-          funcs.zip(numArgs).foreach { case (fs, numArg) =>
-            dataOut.writeInt(numArg)
-            dataOut.writeInt(fs.length)
-            fs.foreach { f =>
+          funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+            dataOut.writeInt(offsets.length)
+            offsets.foreach { offset =>
+              dataOut.writeInt(offset)
+            }
+            dataOut.writeInt(chained.funcs.length)
+            chained.funcs.foreach { f =>
               dataOut.writeInt(f.command.length)
               dataOut.write(f.command)
             }
           }
         } else {
           dataOut.writeInt(0)
-          val command = funcs.head.head.command
+          val command = funcs.head.funcs.head.command
           dataOut.writeInt(command.length)
           dataOut.write(command)
         }
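For illustration, here is a rough standalone Python sketch (not Spark code) of the UDF section of the control stream that the new isUDF branch above writes, for two UDFs where the first reads the column at offset 0 and the second is a chain of two commands reading offsets 0 and 1. The byte strings are placeholders for real pickled commands, and ">i" assumes the big-endian int written by java.io.DataOutputStream.writeInt.

import struct

def write_int(out, value):
    # big-endian 32-bit int, as written by DataOutputStream.writeInt
    out.append(struct.pack(">i", value))

out = []
write_int(out, 1)                         # isUDF flag
write_int(out, 2)                         # number of independent UDFs
write_int(out, 1)                         # UDF 0: one argument ...
write_int(out, 0)                         #   ... at offset 0
write_int(out, 1)                         # UDF 0: one chained command
command = b"<pickled f0>"                 # placeholder, not a real pickle
write_int(out, len(command))
out.append(command)
write_int(out, 2)                         # UDF 1: two arguments ...
write_int(out, 0)                         #   ... at offset 0
write_int(out, 1)                         #   ... and offset 1
write_int(out, 2)                         # UDF 1: two chained commands
for command in (b"<pickled g>", b"<pickled h>"):
    write_int(out, len(command))
    out.append(command)
stream = b"".join(out)                    # what the Python worker reads back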
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index b7bb1edc40890..849516c8106be 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -63,11 +63,13 @@ def chain(f, g):
 
 
 def wrap_udf(f, return_type):
-    return lambda *a: return_type.toInternal(f(*a))
+    toInternal = return_type.toInternal
+    return lambda *a: toInternal(f(*a))
 
 
 def read_single_udf(pickleSer, infile):
     num_arg = read_int(infile)
+    arg_offsets = [read_int(infile) for i in range(num_arg)]
     row_func = None
     for i in range(read_int(infile)):
         f, return_type = read_command(pickleSer, infile)
@@ -76,27 +78,27 @@ def read_single_udf(pickleSer, infile):
         else:
             row_func = chain(row_func, f)
     # the last returnType will be the return type of UDF
-    return num_arg, wrap_udf(row_func, return_type)
+    return arg_offsets, wrap_udf(row_func, return_type)
 
 
 def read_udfs(pickleSer, infile):
     num_udfs = read_int(infile)
-    udfs = []
-    offset = 0
-    for i in range(num_udfs):
-        num_arg, udf = read_single_udf(pickleSer, infile)
-        udfs.append((offset, offset + num_arg, udf))
-        offset += num_arg
     if num_udfs == 1:
-        udf = udfs[0][2]
         # fast path for single UDF
-        def mapper(args):
-            return udf(*args)
+        _, udf = read_single_udf(pickleSer, infile)
+        mapper = lambda a: udf(*a)
     else:
-        def mapper(args):
-            return tuple(udf(*args[start:end]) for start, end, udf in udfs)
+        udfs = {}
+        call_udf = []
+        for i in range(num_udfs):
+            arg_offsets, udf = read_single_udf(pickleSer, infile)
+            udfs['f%d' % i] = udf
+            args = ["a[%d]" % o for o in arg_offsets]
+            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+        # Create a function like this:
+        #   lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
+        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+        mapper = eval(mapper_str, udfs)
 
     func = lambda _, it: map(mapper, it)
     ser = AutoBatchedSerializer(PickleSerializer())
@@ -149,8 +151,8 @@ def main(infile, outfile):
                 _broadcastRegistry.pop(bid)
 
         _accumulatorRegistry.clear()
-        is_udf = read_int(infile)
-        if is_udf:
+        is_sql_udf = read_int(infile)
+        if is_sql_udf:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
        else:
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
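As a standalone illustration of the mapper built by read_udfs above, the snippet below hard-codes two toy UDFs and their argument offsets instead of reading them from the stream; the names f0/f1, the lambdas, and the offsets are illustrative only.

udfs = {'f0': lambda x: x + 1,
        'f1': lambda x, y: x * y}
arg_offsets = [[0], [0, 1]]                            # f0 reads a[0]; f1 reads a[0] and a[1]
call_udf = ["f%d(%s)" % (i, ", ".join("a[%d]" % o for o in offsets))
            for i, offsets in enumerate(arg_offsets)]
mapper_str = "lambda a: (%s)" % ", ".join(call_udf)    # "lambda a: (f0(a[0]), f1(a[0], a[1]))"
mapper = eval(mapper_str, udfs)
print(mapper((3, 4)))                                  # -> (4, 12)

Building the lambda as a string and eval-ing it keeps the per-row work in one compiled expression rather than a Python-level loop over the UDFs.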
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
index a9c8bd4f6752b..180d2f375d01f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -18,16 +18,17 @@
 package org.apache.spark.sql.execution.python
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 
 import net.razorvine.pickle.{Pickler, Unpickler}
 
 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.{PythonFunction, PythonRunner}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonFunction, PythonRunner}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
 
 
 /**
@@ -45,15 +46,15 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
 
   def children: Seq[SparkPlan] = child :: Nil
 
-  private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
+  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonUDF) =>
-        val (fs, children) = collectFunctions(u)
-        (fs ++ Seq(udf.func), children)
+        val (chained, children) = collectFunctions(u)
+        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
       case children =>
         // There should not be any other UDFs, or the children can't be evaluated directly.
         assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
-        (Seq(udf.func), udf.children)
+        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
     }
   }
 
@@ -69,22 +70,40 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
     // combine input with output from Python.
     val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
 
-    val (pyFuncs, children) = udfs.map(collectFunctions).unzip
-    val numArgs = children.map(_.length)
+    val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
 
-    val pickle = new Pickler
+    // Most of the inputs are primitives; skip the pickler memo for better performance
+    val pickle = new Pickler(false)
     // flatten all the arguments
-    val allChildren = children.flatMap(x => x)
-    val currentRow = newMutableProjection(allChildren, child.output)()
-    val fields = allChildren.map(_.dataType)
-    val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
+    val allInputs = new ArrayBuffer[Expression]
+    val dataTypes = new ArrayBuffer[DataType]
+    val argOffsets = inputs.map { input =>
+      input.map { e =>
+        if (allInputs.exists(_.semanticEquals(e))) {
+          allInputs.indexWhere(_.semanticEquals(e))
+        } else {
+          allInputs += e
+          dataTypes += e.dataType
+          allInputs.length - 1
+        }
+      }
+    }
+    val projection = newMutableProjection(allInputs, child.output)()
 
     // Input iterator to Python: input rows are grouped so we send them in batches to Python.
     // For each row, add it to the queue.
-    val inputIterator = iter.grouped(100).map { inputRows =>
-      val toBePickled = inputRows.map { row =>
-        queue.add(row)
-        EvaluatePython.toJava(currentRow(row), schema)
+    val inputIterator = iter.grouped(1024).map { inputRows =>
+      val toBePickled = inputRows.map { inputRow =>
+        queue.add(inputRow)
+        val row = projection(inputRow)
+        val fields = new Array[Any](row.numFields)
+        var i = 0
+        while (i < row.numFields) {
+          val dt = dataTypes(i)
+          fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+          i += 1
+        }
+        fields
      }.toArray
      pickle.dumps(toBePickled)
    }
@@ -92,7 +111,7 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
    val context = TaskContext.get()
 
    // Output iterator for results from Python.
-    val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, numArgs)
+    val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
      .compute(inputIterator, context.partitionId(), context)
 
    val unpickle = new Unpickler
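To make the argOffsets computation above concrete, here is a small Python restatement of the same de-duplication idea, with plain strings standing in for Catalyst expressions and == standing in for semanticEquals; the column names and values are illustrative only.

inputs = [["x"], ["x", "y"]]          # argument expressions of udf0 and udf1
all_inputs = []                       # de-duplicated arguments, in first-seen order
arg_offsets = []                      # per-UDF offsets into all_inputs
for args in inputs:
    offsets = []
    for e in args:
        if e not in all_inputs:
            all_inputs.append(e)
        offsets.append(all_inputs.index(e))
    arg_offsets.append(offsets)
print(all_inputs)                     # ['x', 'y']    -> only two columns are pickled per row
print(arg_offsets)                    # [[0], [0, 1]] -> what gets passed to PythonRunner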