improve performance, address comments
Davies Liu committed Mar 31, 2016
1 parent 8e6e5bc commit 8dc1adf
Showing 3 changed files with 86 additions and 50 deletions.
45 changes: 30 additions & 15 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -77,30 +77,42 @@ private[spark] case class PythonFunction(
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]])

/**
* A wrapper for chained Python functions (from bottom to top).
* @param funcs the chained functions, ordered from bottom (applied first) to top (applied last)
*/
private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])

object PythonRunner {
private[spark] object PythonRunner {
def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
new PythonRunner(Seq(Seq(func)), bufferSize, reuse_worker, false, Seq(1))
new PythonRunner(
Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Seq(Seq(0)))
}
}

/**
* A helper class to run Python mapPartition/UDFs in Spark.
*
* funcs is a list of independent Python functions; each entry is itself a chain of Python
* functions, applied from bottom to top.
*/
private[spark] class PythonRunner(
funcs: Seq[Seq[PythonFunction]],
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuse_worker: Boolean,
isUDF: Boolean,
numArgs: Seq[Int])
argOffsets: Seq[Seq[Int]])
extends Logging {

require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")

// All the Python functions should have the same exec, version and envvars.
private val envVars = funcs.head.head.envVars
private val pythonExec = funcs.head.head.pythonExec
private val pythonVer = funcs.head.head.pythonVer
private val envVars = funcs.head.funcs.head.envVars
private val pythonExec = funcs.head.funcs.head.pythonExec
private val pythonVer = funcs.head.funcs.head.pythonVer

private val accumulator = funcs.head.head.accumulator // TODO: support accumulator in multiple UDF
// TODO: support accumulator in multiple UDF
private val accumulator = funcs.head.funcs.head.accumulator

def compute(
inputIterator: Iterator[_],
@@ -240,8 +252,8 @@ private[spark] class PythonRunner(

@volatile private var _exception: Exception = null

private val pythonIncludes = funcs.flatMap(_.flatMap(_.pythonIncludes.asScala)).toSet
private val broadcastVars = funcs.flatMap(_.flatMap(_.broadcastVars.asScala))
private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))

setDaemon(true)

@@ -295,17 +307,20 @@ private[spark] class PythonRunner(
if (isUDF) {
dataOut.writeInt(1)
dataOut.writeInt(funcs.length)
funcs.zip(numArgs).foreach { case (fs, numArg) =>
dataOut.writeInt(numArg)
dataOut.writeInt(fs.length)
fs.foreach { f =>
funcs.zip(argOffsets).foreach { case (chained, offsets) =>
dataOut.writeInt(offsets.length)
offsets.foreach { offset =>
dataOut.writeInt(offset)
}
dataOut.writeInt(chained.funcs.length)
chained.funcs.foreach { f =>
dataOut.writeInt(f.command.length)
dataOut.write(f.command)
}
}
} else {
dataOut.writeInt(0)
val command = funcs.head.head.command
val command = funcs.head.funcs.head.command
dataOut.writeInt(command.length)
dataOut.write(command)
}
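For reference, a minimal standalone sketch (not part of this commit) of the framing that the isUDF branch above writes to the worker: the UDF count, then for each UDF its argument offsets followed by the bottom-to-top chain of pickled commands. The UDF names and pickled payloads below are hypothetical placeholders.

import struct

def write_int(buf, v):
    buf.append(struct.pack(">i", v))     # matches Java DataOutputStream.writeInt (big-endian)

def write_udfs(buf, udfs):
    """udfs: list of (arg_offsets, [pickled_command, ...]), one entry per chained UDF."""
    write_int(buf, 1)                    # isUDF flag
    write_int(buf, len(udfs))            # number of (chained) UDFs
    for arg_offsets, commands in udfs:
        write_int(buf, len(arg_offsets))
        for offset in arg_offsets:       # which projected columns feed this UDF
            write_int(buf, offset)
        write_int(buf, len(commands))    # length of the bottom-to-top chain
        for command in commands:
            write_int(buf, len(command))
            buf.append(command)          # pickled (func, return_type) payload

# f(col0) is a single function; g(col1, col2) is a chain of two functions.
buf = []
write_udfs(buf, [([0], [b"<pickled f>"]),
                 ([1, 2], [b"<pickled g_bottom>", b"<pickled g_top>"])])
payload = b"".join(buf)
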
36 changes: 19 additions & 17 deletions python/pyspark/worker.py
@@ -63,11 +63,13 @@ def chain(f, g):


def wrap_udf(f, return_type):
return lambda *a: return_type.toInternal(f(*a))
toInternal = return_type.toInternal
return lambda *a: toInternal(f(*a))


def read_single_udf(pickleSer, infile):
num_arg = read_int(infile)
arg_offsets = [read_int(infile) for i in range(num_arg)]
row_func = None
for i in range(read_int(infile)):
f, return_type = read_command(pickleSer, infile)
@@ -76,27 +78,27 @@ def read_single_udf(pickleSer, infile):
else:
row_func = chain(row_func, f)
# the last returnType will be the return type of UDF
return num_arg, wrap_udf(row_func, return_type)
return arg_offsets, wrap_udf(row_func, return_type)
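
As an illustration of the chaining above (a sketch, not part of worker.py): the functions arrive bottom-to-top, each later function wraps the previous one, and only the return type of the last function is used to convert the result. compose() below stands in for chain(), assumed to apply its first argument before its second, and int stands in for return_type.toInternal.

def compose(f, g):
    return lambda *a: g(f(*a))       # apply f first, then g

def wrap(f, to_internal):
    return lambda *a: to_internal(f(*a))

f_bottom = lambda x: x + 1           # first (bottom) function in the chain
f_top = lambda x: x * 10             # last (top) function; its return type wins

udf = wrap(compose(f_bottom, f_top), int)
assert udf(4) == 50                  # int(f_top(f_bottom(4)))
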


def read_udfs(pickleSer, infile):
num_udfs = read_int(infile)
udfs = []
offset = 0
for i in range(num_udfs):
num_arg, udf = read_single_udf(pickleSer, infile)
udfs.append((offset, offset + num_arg, udf))
offset += num_arg

if num_udfs == 1:
udf = udfs[0][2]

# fast path for single UDF
def mapper(args):
return udf(*args)
_, udf = read_single_udf(pickleSer, infile)
mapper = lambda a: udf(*a)
else:
def mapper(args):
return tuple(udf(*args[start:end]) for start, end, udf in udfs)
udfs = {}
call_udf = []
for i in range(num_udfs):
arg_offsets, udf = read_single_udf(pickleSer, infile)
udfs['f%d' % i] = udf
args = ["a[%d]" % o for o in arg_offsets]
call_udf.append("f%d(%s)" % (i, ", ".join(args)))
# Create function like this:
# lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
mapper = eval(mapper_str, udfs)

func = lambda _, it: map(mapper, it)
ser = AutoBatchedSerializer(PickleSerializer())
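
To make the generated mapper concrete, here is a standalone sketch (hypothetical UDFs, nothing read from a stream) for three UDFs whose argument offsets are [0], [1, 2] and [3]:

udfs = {"f0": lambda a: a + 1,
        "f1": lambda a, b: a * b,
        "f2": str}
mapper_str = "lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))"
mapper = eval(mapper_str, udfs)      # the udfs dict serves as the lambda's globals

row = (10, 2, 3, 4.5)                # one projected input row
assert mapper(row) == (11, 6, "4.5")

Building one flat lambda this way is presumably the worker-side part of the performance gain named in the commit message: it avoids the per-row generator and args[start:end] slicing of the previous mapper.
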
@@ -149,8 +151,8 @@ def main(infile, outfile):
_broadcastRegistry.pop(bid)

_accumulatorRegistry.clear()
is_udf = read_int(infile)
if is_udf:
is_sql_udf = read_int(infile)
if is_sql_udf:
func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
else:
func, profiler, deserializer, serializer = read_command(pickleSer, infile)
55 changes: 37 additions & 18 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -18,16 +18,17 @@
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{PythonFunction, PythonRunner}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonFunction, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.types.{DataType, StructField, StructType}


/**
@@ -45,15 +46,15 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c

def children: Seq[SparkPlan] = child :: Nil

private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
udf.children match {
case Seq(u: PythonUDF) =>
val (fs, children) = collectFunctions(u)
(fs ++ Seq(udf.func), children)
val (chained, children) = collectFunctions(u)
(ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
case children =>
// There should not be any other UDFs, or the children can't be evaluated directly.
assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
(Seq(udf.func), udf.children)
(ChainedPythonFunctions(Seq(udf.func)), udf.children)
}
}

@@ -69,30 +70,48 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
// combine input with output from Python.
val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()

val (pyFuncs, children) = udfs.map(collectFunctions).unzip
val numArgs = children.map(_.length)
val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip

val pickle = new Pickler
// Most of the inputs are primitives, do not use memo for better performance
val pickle = new Pickler(false)
// flatten all the arguments
val allChildren = children.flatMap(x => x)
val currentRow = newMutableProjection(allChildren, child.output)()
val fields = allChildren.map(_.dataType)
val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
val allInputs = new ArrayBuffer[Expression]
val dataTypes = new ArrayBuffer[DataType]
val argOffsets = inputs.map { input =>
input.map { e =>
if (allInputs.exists(_.semanticEquals(e))) {
allInputs.indexWhere(_.semanticEquals(e))
} else {
allInputs += e
dataTypes += e.dataType
allInputs.length - 1
}
}
}
val projection = newMutableProjection(allInputs, child.output)()

// Input iterator to Python: input rows are grouped so we send them in batches to Python.
// For each row, add it to the queue.
val inputIterator = iter.grouped(100).map { inputRows =>
val toBePickled = inputRows.map { row =>
queue.add(row)
EvaluatePython.toJava(currentRow(row), schema)
val inputIterator = iter.grouped(1024).map { inputRows =>
val toBePickled = inputRows.map { inputRow =>
queue.add(inputRow)
val row = projection(inputRow)
val fields = new Array[Any](row.numFields)
var i = 0
while (i < row.numFields) {
val dt = dataTypes(i)
fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
i += 1
}
fields
}.toArray
pickle.dumps(toBePickled)
}

val context = TaskContext.get()

// Output iterator for results from Python.
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, numArgs)
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
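
The argument-offset bookkeeping added above can be summarized with a small sketch (plain Python, using == where the Scala code uses semanticEquals): each distinct input expression is projected and pickled once, and every UDF refers to its inputs by position.

def dedup_inputs(inputs_per_udf):
    all_inputs, arg_offsets = [], []
    for udf_inputs in inputs_per_udf:
        offsets = []
        for e in udf_inputs:
            if e in all_inputs:
                offsets.append(all_inputs.index(e))   # reuse the already-projected column
            else:
                all_inputs.append(e)                  # project this expression once
                offsets.append(len(all_inputs) - 1)
        arg_offsets.append(offsets)
    return all_inputs, arg_offsets

# f(a, b) and g(b, c) share column b; it is sent to the worker only once.
assert dedup_inputs([["a", "b"], ["b", "c"]]) == (["a", "b", "c"], [[0, 1], [1, 2]])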

