[SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch #12057
Changes from all commits
f6b7373
8e6e5bc
8dc1adf
dd71ba9
8597bba
72a5ec0
876f9f9
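For context, here is a minimal sketch of the kind of query this change targets: two Python UDFs in a single projection, which with this patch are shipped to the Python worker together and evaluated in one batch. The session setup, column names, and UDFs below are invented for illustration.

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.master("local[1]").appName("multi-udf-demo").getOrCreate()
df = spark.createDataFrame([(1, 2), (3, 4)], ["a", "b"])

plus_one = udf(lambda x: x + 1, IntegerType())
times = udf(lambda x, y: x * y, IntegerType())

# Both UDFs live in the same projection; with this patch they are evaluated in a
# single pass over each batch of input rows instead of one pass per UDF.
df.select(plus_one(df["a"]).alias("a1"), times(df["a"], df["b"]).alias("ab")).show()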
python/pyspark/worker.py
@@ -29,7 +29,7 @@
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
from pyspark import shuffle

pickleSer = PickleSerializer()
@@ -59,7 +59,54 @@ def read_command(serializer, file):

def chain(f, g):
    """chain two functions together """
    return lambda x: g(f(x))
    return lambda *a: g(f(*a))
Review comment: Woah, didn't know that you could do varargs lambdas. Cool!
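A quick illustration of the varargs chaining (functions invented for illustration, not from the patch): chain composes two functions while still letting the first one take multiple arguments.

def chain(f, g):
    """chain two functions together """
    return lambda *a: g(f(*a))

add = lambda x, y: x + y     # the first function may take several arguments
double = lambda x: x * 2     # the second consumes the first one's result

composed = chain(add, double)
print(composed(3, 4))        # 14, i.e. double(add(3, 4))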


def wrap_udf(f, return_type):
    if return_type.needConversion():
        toInternal = return_type.toInternal
        return lambda *a: toInternal(f(*a))
    else:
        return lambda *a: f(*a)

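For reference, needConversion and toInternal are existing hooks on pyspark.sql.types.DataType that wrap_udf relies on; a small illustration (types chosen arbitrarily):

from datetime import date
from pyspark.sql.types import DateType, StringType

# Strings need no conversion to Spark's internal representation; dates do
# (internally they become days since the Unix epoch).
print(StringType().needConversion())             # False
print(DateType().needConversion())               # True
print(DateType().toInternal(date(1970, 1, 11)))  # 10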
def read_single_udf(pickleSer, infile):
    num_arg = read_int(infile)
    arg_offsets = [read_int(infile) for i in range(num_arg)]
    row_func = None
    for i in range(read_int(infile)):
        f, return_type = read_command(pickleSer, infile)
        if row_func is None:
            row_func = f
        else:
            row_func = chain(row_func, f)
    # the last returnType will be the return type of UDF
    return arg_offsets, wrap_udf(row_func, return_type)

def read_udfs(pickleSer, infile):
    num_udfs = read_int(infile)
    if num_udfs == 1:
        # fast path for single UDF
        _, udf = read_single_udf(pickleSer, infile)
        mapper = lambda a: udf(*a)
    else:
        udfs = {}
        call_udf = []
        for i in range(num_udfs):
            arg_offsets, udf = read_single_udf(pickleSer, infile)
            udfs['f%d' % i] = udf
            args = ["a[%d]" % o for o in arg_offsets]
            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
        # Create function like this:
        #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
Review comment: Clever! This is a neat trick.
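To make the trick concrete, this is roughly what the generated string and the resulting callable look like for two hypothetical UDFs (functions and offsets invented for illustration):

udfs = {'f0': lambda x: x + 1, 'f1': lambda x, y: x * y}
call_udf = ["f0(a[0])", "f1(a[1], a[2])"]
mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
# mapper_str == "lambda a: (f0(a[0]), f1(a[1], a[2]))"
mapper = eval(mapper_str, udfs)   # udfs acts as the lambda's globals
print(mapper((3, 4, 5)))          # (4, 20)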
        mapper = eval(mapper_str, udfs)

    func = lambda _, it: map(mapper, it)
    ser = BatchedSerializer(PickleSerializer(), 100)
Review comment: What serializer did we use before? 100 seems arbitrary here.
Reply: Before this patch, we used AutoBatchedSerializer, which could hold thousands of rows (holding more rows in the JVM may cause OOM).
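For reference, both serializers discussed here are real classes in pyspark.serializers; the change replaces size-driven batching with a fixed 100-row batch:

from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, PickleSerializer

# Previously: the batch size adapts to the serialized size, so one batch may
# hold thousands of small rows that the JVM side has to buffer.
auto = AutoBatchedSerializer(PickleSerializer())

# With this patch: at most 100 rows per pickled batch of UDF results.
fixed = BatchedSerializer(PickleSerializer(), 100)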
    # profiling is not supported for UDF
    return func, None, ser, ser


def main(infile, outfile):
@@ -107,21 +154,10 @@ def main(infile, outfile):
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        row_based = read_int(infile)
        num_commands = read_int(infile)
        if row_based:
            profiler = None  # profiling is not supported for UDF
            row_func = None
            for i in range(num_commands):
                f, returnType, deserializer = read_command(pickleSer, infile)
                if row_func is None:
                    row_func = f
                else:
                    row_func = chain(row_func, f)
            serializer = deserializer
            func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
        is_sql_udf = read_int(infile)
        if is_sql_udf:
            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
        else:
            assert num_commands == 1
            func, profiler, deserializer, serializer = read_command(pickleSer, infile)

        init_time = time.time()
BatchPythonEvaluation.scala
@@ -18,16 +18,17 @@
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.spark.TaskContext
import org.apache.spark.api.python.{PythonFunction, PythonRunner}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.types.{DataType, StructField, StructType}


/**
@@ -40,20 +41,20 @@ import org.apache.spark.sql.types.{StructField, StructType}
 * we drain the queue to find the original input row. Note that if the Python process is way too
 * slow, this could lead to the queue growing unbounded and eventually run out of memory.
 */
case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
  extends SparkPlan {

  def children: Seq[SparkPlan] = child :: Nil

  private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
    udf.children match {
      case Seq(u: PythonUDF) =>
        val (fs, children) = collectFunctions(u)
        (fs ++ Seq(udf.func), children)
        val (chained, children) = collectFunctions(u)
        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
      case children =>
        // There should not be any other UDFs, or the children can't be evaluated directly.
        assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
        (Seq(udf.func), udf.children)
        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
    }
  }

@@ -69,39 +70,78 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
      // combine input with output from Python.
      val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()

      val (pyFuncs, children) = collectFunctions(udf)

      val pickle = new Pickler
      val currentRow = newMutableProjection(children, child.output)()
      val fields = children.map(_.dataType)
      val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip

      // flatten all the arguments
      val allInputs = new ArrayBuffer[Expression]
      val dataTypes = new ArrayBuffer[DataType]
      val argOffsets = inputs.map { input =>
        input.map { e =>
          if (allInputs.exists(_.semanticEquals(e))) {
Review comment: In the worst case this loop is N^2, but N is probably pretty small, so it probably doesn't matter compared to other perf issues impacting Python UDFs.
Reply: Agreed.
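The shape of the flattening being built here, sketched in Python with made-up column expressions (the patch does the equivalent over Catalyst expressions using semanticEquals):

# Hypothetical inputs: udf0 takes (col_a,), udf1 takes (col_b, col_a).
udf_inputs = [["col_a"], ["col_b", "col_a"]]

all_inputs = []    # distinct argument expressions, in first-seen order
arg_offsets = []   # per UDF: positions of its arguments within all_inputs
for inputs in udf_inputs:
    offsets = []
    for e in inputs:
        if e in all_inputs:
            offsets.append(all_inputs.index(e))
        else:
            all_inputs.append(e)
            offsets.append(len(all_inputs) - 1)
    arg_offsets.append(offsets)

print(all_inputs)    # ['col_a', 'col_b']
print(arg_offsets)   # [[0], [1, 0]]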
            allInputs.indexWhere(_.semanticEquals(e))
          } else {
            allInputs += e
            dataTypes += e.dataType
            allInputs.length - 1
          }
        }.toArray
      }.toArray
      val projection = newMutableProjection(allInputs, child.output)()
      val schema = StructType(dataTypes.map(dt => StructField("", dt)))
Review comment: I'm not sure whether the name of this struct field will ever appear / get used anywhere, but if you wanted to provide a nicer name I suppose you could pull in the name or .toString from the expression in
Reply: It's not used anywhere; we convert the row into a tuple, then pass the items into the UDFs.
      val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)

      // enable memo iff we serialize the row with schema (schema and class should be memorized)
      val pickle = new Pickler(needConversion)
      // Input iterator to Python: input rows are grouped so we send them in batches to Python.
      // For each row, add it to the queue.
      val inputIterator = iter.grouped(100).map { inputRows =>
        val toBePickled = inputRows.map { row =>
          queue.add(row)
          EvaluatePython.toJava(currentRow(row), schema)
        val toBePickled = inputRows.map { inputRow =>
          queue.add(inputRow)
          val row = projection(inputRow)
          if (needConversion) {
            EvaluatePython.toJava(row, schema)
          } else {
            // fast path for these types that does not need conversion in Python
            val fields = new Array[Any](row.numFields)
            var i = 0
            while (i < row.numFields) {
              val dt = dataTypes(i)
              fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
              i += 1
            }
            fields
          }
        }.toArray
        pickle.dumps(toBePickled)
      }

      val context = TaskContext.get()

      // Output iterator for results from Python.
      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
        .compute(inputIterator, context.partitionId(), context)

      val unpickle = new Unpickler
      val row = new GenericMutableRow(1)
      val mutableRow = new GenericMutableRow(1)
      val joined = new JoinedRow
      val resultType = if (udfs.length == 1) {
        udfs.head.dataType
      } else {
        StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
      }
      val resultProj = UnsafeProjection.create(output, output)

      outputIterator.flatMap { pickedResult =>
        val unpickledBatch = unpickle.loads(pickedResult)
        unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
      }.map { result =>
        row(0) = EvaluatePython.fromJava(result, udf.dataType)
        val row = if (udfs.length == 1) {
Review comment: Rather than evaluating this
Review comment: If you do this, you could reduce the scope of the
Reply: Compared to evaluating the Python UDF, I think this does not matter; the JIT compiler could predict this branch pretty easily.
Reply: Fair enough.
// fast path for single UDF | ||
mutableRow(0) = EvaluatePython.fromJava(result, resultType) | ||
mutableRow | ||
} else { | ||
EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] | ||
} | ||
resultProj(joined(queue.poll(), row)) | ||
} | ||
} | ||
|
Review comment: Similarly, do you mind adding scaladoc for these two new parameters?