[SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch #12057
Changes from 6 commits
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
+    val runner = PythonRunner(func, bufferSize, reuse_worker)
     runner.compute(firstParent.iterator(split, context), split.index, context)
   }
 }
@@ -78,21 +78,41 @@ private[spark] case class PythonFunction(
     accumulator: Accumulator[JList[Array[Byte]]])

 /**
- * A helper class to run Python UDFs in Spark.
+ * A wrapper for chained Python functions (from bottom to top).
+ * @param funcs
+ */
+private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
+
+private[spark] object PythonRunner {
+  def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
+    new PythonRunner(
+      Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
+  }
+}
+
+/**
+ * A helper class to run Python mapPartition/UDFs in Spark.
+ *
+ * funcs is a list of independent Python functions, each one of them is a list of chained Python
+ * functions (from bottom to top).
  */
 private[spark] class PythonRunner(
-    funcs: Seq[PythonFunction],
+    funcs: Seq[ChainedPythonFunctions],
     bufferSize: Int,
     reuse_worker: Boolean,
-    rowBased: Boolean)
+    isUDF: Boolean,
+    argOffsets: Array[Array[Int]])
   extends Logging {

+  require(funcs.length == argOffsets.length, "numArgs should have the same length as funcs")
Review comment: numArgs -> argOffsets

   // All the Python functions should have the same exec, version and envvars.
-  private val envVars = funcs.head.envVars
-  private val pythonExec = funcs.head.pythonExec
-  private val pythonVer = funcs.head.pythonVer
+  private val envVars = funcs.head.funcs.head.envVars
+  private val pythonExec = funcs.head.funcs.head.pythonExec
+  private val pythonVer = funcs.head.funcs.head.pythonVer

-  private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF
+  // TODO: support accumulator in multiple UDF
+  private val accumulator = funcs.head.funcs.head.accumulator

   def compute(
       inputIterator: Iterator[_],
@@ -232,8 +252,8 @@ private[spark] class PythonRunner(

   @volatile private var _exception: Exception = null

-  private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
-  private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)
+  private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+  private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))

   setDaemon(true)
@@ -284,11 +304,25 @@ private[spark] class PythonRunner(
         }
         dataOut.flush()
         // Serialized command:
-        dataOut.writeInt(if (rowBased) 1 else 0)
-        dataOut.writeInt(funcs.length)
-        funcs.foreach { f =>
-          dataOut.writeInt(f.command.length)
-          dataOut.write(f.command)
+        if (isUDF) {
+          dataOut.writeInt(1)
+          dataOut.writeInt(funcs.length)
+          funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+            dataOut.writeInt(offsets.length)
+            offsets.foreach { offset =>
+              dataOut.writeInt(offset)
+            }
+            dataOut.writeInt(chained.funcs.length)
+            chained.funcs.foreach { f =>
+              dataOut.writeInt(f.command.length)
+              dataOut.write(f.command)
+            }
+          }
+        } else {
+          dataOut.writeInt(0)
+          val command = funcs.head.funcs.head.command
+          dataOut.writeInt(command.length)
+          dataOut.write(command)
         }
         // Data values
         PythonRDD.writeIteratorToStream(inputIterator, dataOut)
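For reference, here is a standalone sketch (not Spark code) of the framing the `isUDF` branch above writes, for two hypothetical UDFs: `f0` reading column 0 and `f1` reading columns 1 and 2, each a chain of length one. The pickled command bytes are placeholders.

```python
import struct

def write_int(chunks, value):
    # big-endian 4-byte int, matching Java's DataOutputStream.writeInt
    chunks.append(struct.pack(">i", value))

chunks = []
write_int(chunks, 1)                     # isUDF flag
write_int(chunks, 2)                     # number of independent UDFs
for offsets, command in [([0], b"<pickled f0>"), ([1, 2], b"<pickled f1>")]:
    write_int(chunks, len(offsets))      # number of argument offsets for this UDF
    for offset in offsets:
        write_int(chunks, offset)        # offset into the projected input row
    write_int(chunks, 1)                 # length of the chained-function list
    write_int(chunks, len(command))      # length of the pickled command
    chunks.append(command)               # the pickled command bytes themselves
payload = b"".join(chunks)
```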
@@ -29,7 +29,7 @@
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
-    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
+    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
 from pyspark import shuffle

 pickleSer = PickleSerializer()
@@ -59,7 +59,54 @@ def read_command(serializer, file):

 def chain(f, g):
     """chain two function together """
-    return lambda x: g(f(x))
+    return lambda *a: g(f(*a))
Review comment: Woah, didn't know that you could do varargs lambdas. Cool!
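A quick illustration of why the varargs form matters: only the bottom function of a chain sees the original columns; everything stacked on top takes the previous result. The names below (`add`, `double`) are made up for the example.

```python
def chain(f, g):
    """chain two functions together"""
    return lambda *a: g(f(*a))

add = lambda x, y: x + y      # bottom of the chain: receives the UDF's columns
double = lambda v: v * 2      # outer function: receives add's result
chained = chain(add, double)
assert chained(3, 4) == 14    # double(add(3, 4))
```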
+
+
+def wrap_udf(f, return_type):
+    if return_type.needConversion():
+        toInternal = return_type.toInternal
+        return lambda *a: toInternal(f(*a))
+    else:
+        return lambda *a: f(*a)
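As a rough illustration of what `needConversion()` guards (assuming a local pyspark installation): types such as `TimestampType` must be converted to Spark's internal representation, while primitive types like `LongType` can be returned as-is, so `wrap_udf` skips the extra call for them.

```python
from datetime import datetime
from pyspark.sql.types import LongType, TimestampType

assert LongType().needConversion() is False        # fast path: no wrapping needed
assert TimestampType().needConversion() is True    # must go through toInternal

# toInternal converts a datetime into the internal value used by Spark SQL
internal = TimestampType().toInternal(datetime(2016, 3, 30, 12, 0, 0))
assert isinstance(internal, int)                   # microseconds since the epoch
```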
+
+
+def read_single_udf(pickleSer, infile):
+    num_arg = read_int(infile)
+    arg_offsets = [read_int(infile) for i in range(num_arg)]
+    row_func = None
+    for i in range(read_int(infile)):
+        f, return_type = read_command(pickleSer, infile)
+        if row_func is None:
+            row_func = f
+        else:
+            row_func = chain(row_func, f)
+    # the last returnType will be the return type of UDF
+    return arg_offsets, wrap_udf(row_func, return_type)
+
+
+def read_udfs(pickleSer, infile):
+    num_udfs = read_int(infile)
+    if num_udfs == 1:
+        # fast path for single UDF
+        _, udf = read_single_udf(pickleSer, infile)
+        mapper = lambda a: udf(*a)
+    else:
+        udfs = {}
+        call_udf = []
+        for i in range(num_udfs):
+            arg_offsets, udf = read_single_udf(pickleSer, infile)
+            udfs['f%d' % i] = udf
+            args = ["a[%d]" % o for o in arg_offsets]
+            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+        # Create function like this:
+        #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
+        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
Review comment: Clever! This is a neat trick.
+        mapper = eval(mapper_str, udfs)
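A standalone sketch of the generated mapper for two hypothetical UDFs, `f0` over column 0 and `f1` over columns 1 and 2; `eval()` compiles the string once, so there is no per-row string handling.

```python
udfs = {"f0": lambda x: x + 1, "f1": lambda x, y: x * y}
call_udf = ["f0(a[0])", "f1(a[1], a[2])"]
mapper_str = "lambda a: (%s)" % ", ".join(call_udf)
# mapper_str == "lambda a: (f0(a[0]), f1(a[1], a[2]))"
mapper = eval(mapper_str, udfs)        # udfs acts as the lambda's globals
assert mapper((10, 2, 3)) == (11, 6)   # one input row -> one tuple of UDF results
```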
+
+    func = lambda _, it: map(mapper, it)
+    ser = BatchedSerializer(PickleSerializer(), 100)
Review comment: What serializer did we use before? 100 seems arbitrary here.
Reply: Before this patch, we used AutoBatchedSerializer, which could hold thousands of rows (holding more rows in the JVM may cause OOM).
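A rough sketch of the difference being discussed (assuming pyspark is importable): `BatchedSerializer` with an explicit size pickles results in fixed groups of 100 rows, whereas `AutoBatchedSerializer` grows its batches toward a target byte size and can therefore buffer far more rows on the JVM side.

```python
import io
from pyspark.serializers import BatchedSerializer, PickleSerializer

ser = BatchedSerializer(PickleSerializer(), 100)
buf = io.BytesIO()
ser.dump_stream(iter(range(250)), buf)                 # written as batches of 100, 100 and 50
buf.seek(0)
assert list(ser.load_stream(buf)) == list(range(250))  # round-trips the same values
```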
+    # profiling is not supported for UDF
+    return func, None, ser, ser
+
+
 def main(infile, outfile):
@@ -107,21 +154,10 @@ def main(infile, outfile):
             _broadcastRegistry.pop(bid)

     _accumulatorRegistry.clear()
-    row_based = read_int(infile)
-    num_commands = read_int(infile)
-    if row_based:
-        profiler = None  # profiling is not supported for UDF
-        row_func = None
-        for i in range(num_commands):
-            f, returnType, deserializer = read_command(pickleSer, infile)
-            if row_func is None:
-                row_func = f
-            else:
-                row_func = chain(row_func, f)
-        serializer = deserializer
-        func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
+    is_sql_udf = read_int(infile)
+    if is_sql_udf:
+        func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
     else:
-        assert num_commands == 1
         func, profiler, deserializer, serializer = read_command(pickleSer, infile)

     init_time = time.time()
@@ -18,16 +18,17 @@
 package org.apache.spark.sql.execution.python

 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer

 import net.razorvine.pickle.{Pickler, Unpickler}

 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.{PythonFunction, PythonRunner}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}


 /**
@@ -40,20 +41,20 @@ import org.apache.spark.sql.types.{StructField, StructType}
  * we drain the queue to find the original input row. Note that if the Python process is way too
  * slow, this could lead to the queue growing unbounded and eventually run out of memory.
  */
-case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
+case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
   extends SparkPlan {

   def children: Seq[SparkPlan] = child :: Nil

-  private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
+  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonUDF) =>
-        val (fs, children) = collectFunctions(u)
-        (fs ++ Seq(udf.func), children)
+        val (chained, children) = collectFunctions(u)
+        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
       case children =>
         // There should not be any other UDFs, or the children can't be evaluated directly.
         assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
-        (Seq(udf.func), udf.children)
+        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
     }
   }
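A toy model of `collectFunctions` in Python (the class and field names here are illustrative, not Spark's): a PythonUDF whose only child is another PythonUDF is folded into one chain, evaluated bottom-to-top in a single round trip to the worker.

```python
class PyUDF(object):
    """Stand-in for a PythonUDF expression: a function name plus child expressions."""
    def __init__(self, func, children):
        self.func = func
        self.children = children

def collect_functions(udf):
    # g(f(col)) is collected as the chain [f, g] over f's own children
    if len(udf.children) == 1 and isinstance(udf.children[0], PyUDF):
        chained, leaves = collect_functions(udf.children[0])
        return chained + [udf.func], leaves
    return [udf.func], udf.children

inner = PyUDF("f", ["col_a"])
outer = PyUDF("g", [inner])          # g(f(col_a))
assert collect_functions(outer) == (["f", "g"], ["col_a"])
```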
@@ -69,39 +70,78 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       // combine input with output from Python.
       val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()

-      val (pyFuncs, children) = collectFunctions(udf)
-
-      val pickle = new Pickler
-      val currentRow = newMutableProjection(children, child.output)()
-      val fields = children.map(_.dataType)
-      val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
+      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
+
+      // flatten all the arguments
+      val allInputs = new ArrayBuffer[Expression]
+      val dataTypes = new ArrayBuffer[DataType]
+      val argOffsets = inputs.map { input =>
+        input.map { e =>
+          if (allInputs.exists(_.semanticEquals(e))) {
Review comment: In the worst-case this loop is N^2, but N is probably pretty small so it probably doesn't matter compared to other perf. issues impacting Python UDFs.
Reply: Agreed.
+            allInputs.indexWhere(_.semanticEquals(e))
+          } else {
+            allInputs += e
+            dataTypes += e.dataType
+            allInputs.length - 1
+          }
+        }.toArray
+      }.toArray
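The offset computation above, modelled standalone in Python with plain strings in place of Catalyst expressions: every UDF argument is mapped to an index into one deduplicated list of projected columns, so a column shared by several UDFs is shipped to Python only once.

```python
def compute_arg_offsets(udf_inputs):
    all_inputs = []      # deduplicated expressions, in first-seen order
    arg_offsets = []     # per UDF, indices of its arguments in all_inputs
    for inputs in udf_inputs:
        offsets = []
        for e in inputs:
            if e in all_inputs:                    # the linear scan discussed above
                offsets.append(all_inputs.index(e))
            else:
                all_inputs.append(e)
                offsets.append(len(all_inputs) - 1)
        arg_offsets.append(offsets)
    return all_inputs, arg_offsets

# two hypothetical UDFs: f0(a) and f1(a, b) -- column "a" is shared and sent once
assert compute_arg_offsets([["a"], ["a", "b"]]) == (["a", "b"], [[0], [0, 1]])
```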
+      val projection = newMutableProjection(allInputs, child.output)()
+      val schema = StructType(dataTypes.map(dt => StructField("", dt)))
Review comment: I'm not sure whether the name of this struct field will ever appear / get used anywhere, but if you wanted to provide a nicer name I suppose you could pull in the name or .toString from the expression in
Reply: It's not used anywhere; we will convert the row into a tuple, then pass the items into the UDFs.
+      val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
+
+      // enable memo iff we serialize the row with schema (schema and class should be memorized)
+      val pickle = new Pickler(needConversion)
       // Input iterator to Python: input rows are grouped so we send them in batches to Python.
       // For each row, add it to the queue.
       val inputIterator = iter.grouped(100).map { inputRows =>
-        val toBePickled = inputRows.map { row =>
-          queue.add(row)
-          EvaluatePython.toJava(currentRow(row), schema)
+        val toBePickled = inputRows.map { inputRow =>
+          queue.add(inputRow)
+          val row = projection(inputRow)
+          if (needConversion) {
+            EvaluatePython.toJava(row, schema)
+          } else {
+            // fast path for these types that does not need conversion in Python
+            val fields = new Array[Any](row.numFields)
+            var i = 0
+            while (i < row.numFields) {
+              val dt = dataTypes(i)
+              fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+              i += 1
+            }
+            fields
+          }
         }.toArray
         pickle.dumps(toBePickled)
       }

       val context = TaskContext.get()

       // Output iterator for results from Python.
-      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
+      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
         .compute(inputIterator, context.partitionId(), context)

       val unpickle = new Unpickler
-      val row = new GenericMutableRow(1)
+      val mutableRow = new GenericMutableRow(1)
       val joined = new JoinedRow
+      val resultType = if (udfs.length == 1) {
+        udfs.head.dataType
+      } else {
+        StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
+      }
       val resultProj = UnsafeProjection.create(output, output)

       outputIterator.flatMap { pickedResult =>
         val unpickledBatch = unpickle.loads(pickedResult)
         unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
       }.map { result =>
-        row(0) = EvaluatePython.fromJava(result, udf.dataType)
+        val row = if (udfs.length == 1) {
Review comment: Rather than evaluating this
Review comment: If you do this, you could reduce the scope of the
Reply: Compared to evaluating the Python UDF, I think this does not matter; the JIT compiler can predict this branch pretty easily.
Review comment: Fair enough.
+          // fast path for single UDF
+          mutableRow(0) = EvaluatePython.fromJava(result, resultType)
+          mutableRow
+        } else {
+          EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
+        }
         resultProj(joined(queue.poll(), row))
       }
     }
Review comment: Similarly, do you mind adding scaladoc for these two new parameters?