
[SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch #12057

Closed · wants to merge 7 commits
64 changes: 49 additions & 15 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
val runner = PythonRunner(func, bufferSize, reuse_worker)
runner.compute(firstParent.iterator(split, context), split.index, context)
}
}
@@ -78,21 +78,41 @@ private[spark] case class PythonFunction(
accumulator: Accumulator[JList[Array[Byte]]])

/**
* A helper class to run Python UDFs in Spark.
* A wrapper for chained Python functions (from bottom to top).
* @param funcs
*/
private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])

private[spark] object PythonRunner {
def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
new PythonRunner(
Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
}
}

/**
* A helper class to run Python mapPartition/UDFs in Spark.
*
* funcs is a list of independent Python functions, each one of them is a list of chained Python
* functions (from bottom to top).
*/
private[spark] class PythonRunner(
funcs: Seq[PythonFunction],
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuse_worker: Boolean,
rowBased: Boolean)
isUDF: Boolean,
Contributor: Similarly, do you mind adding scaladoc for these two new parameters?

argOffsets: Array[Array[Int]])
extends Logging {

require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")

// All the Python functions should have the same exec, version and envvars.
private val envVars = funcs.head.envVars
private val pythonExec = funcs.head.pythonExec
private val pythonVer = funcs.head.pythonVer
private val envVars = funcs.head.funcs.head.envVars
private val pythonExec = funcs.head.funcs.head.pythonExec
private val pythonVer = funcs.head.funcs.head.pythonVer

private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF
// TODO: support accumulator in multiple UDF
private val accumulator = funcs.head.funcs.head.accumulator

def compute(
inputIterator: Iterator[_],
@@ -232,8 +252,8 @@ private[spark] class PythonRunner(

@volatile private var _exception: Exception = null

private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))

setDaemon(true)

@@ -284,11 +304,25 @@ }
}
dataOut.flush()
// Serialized command:
dataOut.writeInt(if (rowBased) 1 else 0)
dataOut.writeInt(funcs.length)
funcs.foreach { f =>
dataOut.writeInt(f.command.length)
dataOut.write(f.command)
if (isUDF) {
dataOut.writeInt(1)
dataOut.writeInt(funcs.length)
funcs.zip(argOffsets).foreach { case (chained, offsets) =>
dataOut.writeInt(offsets.length)
offsets.foreach { offset =>
dataOut.writeInt(offset)
}
dataOut.writeInt(chained.funcs.length)
chained.funcs.foreach { f =>
dataOut.writeInt(f.command.length)
dataOut.write(f.command)
}
}
} else {
dataOut.writeInt(0)
val command = funcs.head.funcs.head.command
dataOut.writeInt(command.length)
dataOut.write(command)
}
// Data values
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
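To make the new funcs/argOffsets shape concrete, here is a hypothetical illustration (plain Python lists standing in for the Scala Seq[ChainedPythonFunctions] and Array[Array[Int]] above) of what PythonRunner would be handed for a query such as SELECT double(add(a, b)), double(c), where add and double are Python UDFs. The query and the lambdas are assumptions for illustration only, not code from this patch.

```python
# Hypothetical illustration only: Python lists mirroring the Scala structures.
add = lambda x, y: x + y
double = lambda x: x * 2

# One entry per independent UDF expression in the projection; each entry is a
# chain of functions applied bottom-to-top.
funcs = [
    [add, double],   # double(add(a, b))
    [double],        # double(c)
]

# One offsets array per entry in `funcs`: which columns of the deduplicated
# input row feed the bottom function of that chain.
arg_offsets = [
    [0, 1],          # add(a, b) reads columns 0 and 1
    [2],             # double(c) reads column 2
]

assert len(funcs) == len(arg_offsets)   # mirrors the require() in PythonRunner
```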
3 changes: 1 addition & 2 deletions python/pyspark/sql/functions.py
@@ -1649,8 +1649,7 @@ def sort_array(col, asc=True):
# ---------------------------- User Defined Function ----------------------------------

def _wrap_function(sc, func, returnType):
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, returnType, ser)
command = (func, returnType)
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
sc.pythonVer, broadcast_vars, sc._javaAccumulator)
12 changes: 11 additions & 1 deletion python/pyspark/sql/tests.py
@@ -305,7 +305,7 @@ def test_udf2(self):
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])

def test_chained_python_udf(self):
def test_chained_udf(self):
self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(1)").collect()
self.assertEqual(row[0], 2)
@@ -314,6 +314,16 @@ def test_chained_python_udf(self):
[row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
self.assertEqual(row[0], 6)

def test_multiple_udfs(self):
self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
self.assertEqual(tuple(row), (2, 4))
[row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
self.assertEqual(tuple(row), (4, 12))
self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
self.assertEqual(tuple(row), (6, 5))

def test_udf_with_array_type(self):
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
rdd = self.sc.parallelize(d)
68 changes: 52 additions & 16 deletions python/pyspark/worker.py
@@ -29,7 +29,7 @@
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
from pyspark import shuffle

pickleSer = PickleSerializer()
@@ -59,7 +59,54 @@ def read_command(serializer, file):

def chain(f, g):
"""chain two function together """
return lambda x: g(f(x))
return lambda *a: g(f(*a))
Contributor: Woah, didn't know that you could do varargs lambdas. Cool!
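As a quick sanity check of the varargs behaviour discussed here, a minimal standalone sketch; the add/double lambdas are made up for illustration.

```python
def chain(f, g):
    """chain two functions together (same shape as the helper above)"""
    return lambda *a: g(f(*a))

add = lambda x, y: x + y      # bottom of the chain: receives the raw column values
double = lambda x: x * 2      # top of the chain: receives the inner result

composed = chain(add, double)  # equivalent to double(add(x, y))
assert composed(1, 2) == 6
```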



def wrap_udf(f, return_type):
if return_type.needConversion():
toInternal = return_type.toInternal
return lambda *a: toInternal(f(*a))
else:
return lambda *a: f(*a)
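For context on the needConversion() branch above, a small hedged sketch: primitive return types pass through unchanged, while types such as TimestampType must be converted to Spark's internal representation (microseconds since the epoch) before the result is pickled back to the JVM. The exact integer produced below depends on the local timezone.

```python
from datetime import datetime
from pyspark.sql.types import IntegerType, TimestampType

assert IntegerType().needConversion() is False     # fast path: no wrapping needed
assert TimestampType().needConversion() is True    # must go through toInternal()

# toInternal() turns a datetime into microseconds since the epoch (timezone-dependent).
micros = TimestampType().toInternal(datetime(2016, 4, 1, 12, 0, 0))
```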


def read_single_udf(pickleSer, infile):
num_arg = read_int(infile)
arg_offsets = [read_int(infile) for i in range(num_arg)]
row_func = None
for i in range(read_int(infile)):
f, return_type = read_command(pickleSer, infile)
if row_func is None:
row_func = f
else:
row_func = chain(row_func, f)
# the last returnType will be the return type of UDF
return arg_offsets, wrap_udf(row_func, return_type)


def read_udfs(pickleSer, infile):
num_udfs = read_int(infile)
if num_udfs == 1:
# fast path for single UDF
_, udf = read_single_udf(pickleSer, infile)
mapper = lambda a: udf(*a)
else:
udfs = {}
call_udf = []
for i in range(num_udfs):
arg_offsets, udf = read_single_udf(pickleSer, infile)
udfs['f%d' % i] = udf
args = ["a[%d]" % o for o in arg_offsets]
call_udf.append("f%d(%s)" % (i, ", ".join(args)))
# Create function like this:
# lambda a: (f0(a0), f1(a1, a2), f2(a3))
mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
Contributor: Clever! This is a neat trick.
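A minimal sketch of the string-building trick being discussed, with two already-wrapped stand-in UDFs: f0 takes one argument at offset 0, f1 takes two arguments at offsets 1 and 2. The lambdas are assumptions for illustration.

```python
udfs = {'f0': lambda x: x * 2, 'f1': lambda x, y: x + y}
mapper_str = "lambda a: (f0(a[0]), f1(a[1], a[2]))"

# eval() compiles the lambda with `udfs` as its globals, so f0/f1 resolve there.
mapper = eval(mapper_str, udfs)
assert mapper((3, 1, 2)) == (6, 3)
```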

mapper = eval(mapper_str, udfs)

func = lambda _, it: map(mapper, it)
ser = BatchedSerializer(PickleSerializer(), 100)
Contributor: What serializer did we use before? 100 seems arbitrary here.

Contributor (author): Before this patch we used AutoBatchedSerializer, which could hold thousands of rows per batch (holding more rows in the JVM may cause OOM). The 100 here matches the batch size already used on the Java side (iter.grouped(100)).
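To illustrate the trade-off described in this thread, a rough, hedged sketch (not the code path Spark itself takes): BatchedSerializer pickles a fixed number of rows per batch, while AutoBatchedSerializer grows its batch based on the serialized size.

```python
from io import BytesIO
from pyspark.serializers import PickleSerializer, BatchedSerializer, AutoBatchedSerializer

rows = [(i, i * 2) for i in range(1000)]

fixed = BatchedSerializer(PickleSerializer(), 100)  # at most 100 rows per pickled batch
auto = AutoBatchedSerializer(PickleSerializer())    # shown for contrast: adapts batch size to serialized bytes

buf = BytesIO()
fixed.dump_stream(rows, buf)
buf.seek(0)
assert list(fixed.load_stream(buf)) == rows         # round-trips as 10 batches of 100
```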

# profiling is not supported for UDF
return func, None, ser, ser


def main(infile, outfile):
@@ -107,21 +154,10 @@ def main(infile, outfile):
_broadcastRegistry.pop(bid)

_accumulatorRegistry.clear()
row_based = read_int(infile)
num_commands = read_int(infile)
if row_based:
profiler = None # profiling is not supported for UDF
row_func = None
for i in range(num_commands):
f, returnType, deserializer = read_command(pickleSer, infile)
if row_func is None:
row_func = f
else:
row_func = chain(row_func, f)
serializer = deserializer
func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
is_sql_udf = read_int(infile)
if is_sql_udf:
func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
else:
assert num_commands == 1
func, profiler, deserializer, serializer = read_command(pickleSer, infile)

init_time = time.time()
@@ -426,8 +426,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.RepartitionByExpression(expressions, child, nPartitions) =>
exchange.ShuffleExchange(HashPartitioning(
expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
case e @ python.EvaluatePython(udf, child, _) =>
python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case e @ python.EvaluatePython(udfs, child, _) =>
python.BatchPythonEvaluation(udfs, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
case BroadcastHint(child) => planLater(child) :: Nil
case _ => Nil
@@ -18,16 +18,17 @@
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{PythonFunction, PythonRunner}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.types.{DataType, StructField, StructType}


/**
@@ -40,20 +41,20 @@ import org.apache.spark.sql.types.{StructField, StructType}
* we drain the queue to find the original input row. Note that if the Python process is way too
* slow, this could lead to the queue growing unbounded and eventually run out of memory.
*/
case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends SparkPlan {

def children: Seq[SparkPlan] = child :: Nil

private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
udf.children match {
case Seq(u: PythonUDF) =>
val (fs, children) = collectFunctions(u)
(fs ++ Seq(udf.func), children)
val (chained, children) = collectFunctions(u)
(ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
case children =>
// There should not be any other UDFs, or the children can't be evaluated directly.
assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
(Seq(udf.func), udf.children)
(ChainedPythonFunctions(Seq(udf.func)), udf.children)
}
}

@@ -69,39 +70,78 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
// combine input with output from Python.
val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()

val (pyFuncs, children) = collectFunctions(udf)

val pickle = new Pickler
val currentRow = newMutableProjection(children, child.output)()
val fields = children.map(_.dataType)
val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip

// flatten all the arguments
val allInputs = new ArrayBuffer[Expression]
val dataTypes = new ArrayBuffer[DataType]
val argOffsets = inputs.map { input =>
input.map { e =>
if (allInputs.exists(_.semanticEquals(e))) {
Contributor: In the worst case this loop is N^2, but N is probably small enough that it doesn't matter compared to the other performance issues affecting Python UDFs.

Contributor (author): Agreed.

allInputs.indexWhere(_.semanticEquals(e))
} else {
allInputs += e
dataTypes += e.dataType
allInputs.length - 1
}
}.toArray
}.toArray
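For concreteness, a hypothetical worked example of what the deduplication above produces, written as a Python sketch (lists and string column names standing in for the Scala ArrayBuffer of expressions and semanticEquals): for a projection like SELECT f(a), g(a, b), the shared column a gets a single slot.

```python
# Hypothetical illustration of the deduplication loop above, in plain Python.
inputs = [["a"], ["a", "b"]]   # the argument expressions of f(a) and g(a, b)

all_inputs = []                # deduplicated child expressions sent to Python
arg_offsets = []               # per UDF: indices into all_inputs
for udf_args in inputs:
    offsets = []
    for e in udf_args:
        if e in all_inputs:                  # stands in for semanticEquals()
            offsets.append(all_inputs.index(e))
        else:
            all_inputs.append(e)
            offsets.append(len(all_inputs) - 1)
    arg_offsets.append(offsets)

assert all_inputs == ["a", "b"]
assert arg_offsets == [[0], [0, 1]]
```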
val projection = newMutableProjection(allInputs, child.output)()
val schema = StructType(dataTypes.map(dt => StructField("", dt)))
Contributor: I'm not sure whether the name of this struct field will ever appear or get used anywhere, but if you wanted to provide a nicer name you could pull in the name or .toString of the corresponding expression in allInputs.

Contributor (author): It's not used anywhere; we convert the row into a tuple and then pass the items into the UDFs.

val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)

// enable memo iff we serialize the row with schema (schema and class should be memorized)
val pickle = new Pickler(needConversion)
// Input iterator to Python: input rows are grouped so we send them in batches to Python.
// For each row, add it to the queue.
val inputIterator = iter.grouped(100).map { inputRows =>
val toBePickled = inputRows.map { row =>
queue.add(row)
EvaluatePython.toJava(currentRow(row), schema)
val toBePickled = inputRows.map { inputRow =>
queue.add(inputRow)
val row = projection(inputRow)
if (needConversion) {
EvaluatePython.toJava(row, schema)
} else {
// fast path for these types that does not need conversion in Python
val fields = new Array[Any](row.numFields)
var i = 0
while (i < row.numFields) {
val dt = dataTypes(i)
fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
i += 1
}
fields
}
}.toArray
pickle.dumps(toBePickled)
}

val context = TaskContext.get()

// Output iterator for results from Python.
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
val row = new GenericMutableRow(1)
val mutableRow = new GenericMutableRow(1)
val joined = new JoinedRow
val resultType = if (udfs.length == 1) {
udfs.head.dataType
} else {
StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
}
val resultProj = UnsafeProjection.create(output, output)

outputIterator.flatMap { pickedResult =>
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
row(0) = EvaluatePython.fromJava(result, udf.dataType)
val row = if (udfs.length == 1) {
Contributor: Rather than evaluating this if condition for every row, could we lift it out of the map and perform it once while building the RDD DAG? I.e. assign the result of line 108 to a variable and have the if be the last return value of this block.

Contributor: If you do this, you could also reduce the scope of the mutableRow created up on line 99.

Contributor (author): Compared to the cost of evaluating the Python UDF, I think this does not matter; the JIT compiler can predict this branch easily.

Contributor: Fair enough.

// fast path for single UDF
mutableRow(0) = EvaluatePython.fromJava(result, resultType)
mutableRow
} else {
EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
}
resultProj(joined(queue.poll(), row))
}
}
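Finally, a hedged sketch of what the worker hands back per input row when several UDFs are batched together: one tuple with a slot per UDF, which the JVM side decodes against a struct of the UDFs' return types before joining it with the queued input row. The values are just the ones from the test_multiple_udfs example above; this is an illustration, not the decoding code itself.

```python
from pyspark.sql.types import IntegerType, StructField, StructType

# Result schema the JVM expects when two IntegerType UDFs run in one batch;
# field names are left empty, matching the Scala code above.
result_type = StructType([
    StructField("", IntegerType(), True),
    StructField("", IntegerType(), True),
])

result_row = (2, 4)   # mapper output for SELECT double(1), double(2)
assert len(result_row) == len(result_type.fields)
```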