From 8597bbaf7d29de234f3f29e2929ce79c1ec99075 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 30 Mar 2016 23:29:20 -0700
Subject: [PATCH] fix udt with udf

---
 .../python/BatchPythonEvaluation.scala        | 27 ++++++++++++-------
 .../sql/execution/python/EvaluatePython.scala | 10 +++++++
 2 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
index a537ed3bc4199..c9ab40a0a9abf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer
 import net.razorvine.pickle.{Pickler, Unpickler}
 
 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonFunction, PythonRunner}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -72,8 +72,6 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
 
       val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
 
-      // Most of the inputs are primitives, do not use memo for better performance
-      val pickle = new Pickler(false)
       // flatten all the arguments
       val allInputs = new ArrayBuffer[Expression]
       val dataTypes = new ArrayBuffer[DataType]
@@ -89,21 +87,30 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
         }.toArray
       }.toArray
       val projection = newMutableProjection(allInputs, child.output)()
+      val schema = StructType(dataTypes.map(dt => StructField("", dt)))
+      val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
+      // enable memo iff we serialize the row with schema (schema and class should be memorized)
+      val pickle = new Pickler(needConversion)
 
       // Input iterator to Python: input rows are grouped so we send them in batches to Python.
       // For each row, add it to the queue.
       val inputIterator = iter.grouped(100).map { inputRows =>
         val toBePickled = inputRows.map { inputRow =>
           queue.add(inputRow)
           val row = projection(inputRow)
-          val fields = new Array[Any](row.numFields)
-          var i = 0
-          while (i < row.numFields) {
-            val dt = dataTypes(i)
-            fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
-            i += 1
+          if (needConversion) {
+            EvaluatePython.toJava(row, schema)
+          } else {
+            // fast path for these types that does not need conversion in Python
+            val fields = new Array[Any](row.numFields)
+            var i = 0
+            while (i < row.numFields) {
+              val dt = dataTypes(i)
+              fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+              i += 1
+            }
+            fields
           }
-          fields
         }.toArray
         pickle.dumps(toBePickled)
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 582f42042f622..f3d1c44b25b4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -70,6 +70,16 @@ object EvaluatePython {
     }
   }
 
+  def needConversionInPython(dt: DataType): Boolean = dt match {
+    case DateType | TimestampType => true
+    case _: StructType => true
+    case _: UserDefinedType[_] => true
+    case ArrayType(elementType, _) => needConversionInPython(elementType)
+    case MapType(keyType, valueType, _) =>
+      needConversionInPython(keyType) || needConversionInPython(valueType)
+    case _ => false
+  }
+
   /**
    * Helper for converting from Catalyst type to java type suitable for Pyrolite.
    */
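Note on the change: the new EvaluatePython.needConversionInPython decides, per input
column, whether batched rows must be serialized together with their schema. If any
column's type needs conversion on the Python side (dates, timestamps, structs, UDTs,
or arrays/maps containing them), the whole row is converted via
EvaluatePython.toJava(row, schema), and Pickler memoization is turned on because the
schema and class objects repeat across rows; otherwise the per-field fast path is
kept with memoization off, matching the old behavior for primitive-only inputs.
Below is a minimal, self-contained sketch of that type dispatch. The object name and
the simplified stand-in types are hypothetical, chosen so the sketch compiles and
runs without a Spark dependency; the real code matches on org.apache.spark.sql.types.*.

    // Hypothetical, simplified stand-ins for Spark's DataType hierarchy;
    // the actual patch matches on org.apache.spark.sql.types.* instead.
    object NeedConversionSketch extends App {
      sealed trait DataType
      case object IntegerType extends DataType
      case object DateType extends DataType
      case object TimestampType extends DataType
      case class StructType(fields: Seq[DataType]) extends DataType
      case class UserDefinedType(sqlType: DataType) extends DataType
      case class ArrayType(elementType: DataType) extends DataType
      case class MapType(keyType: DataType, valueType: DataType) extends DataType

      // Mirrors needConversionInPython: true when values of this type cannot be
      // pickled field-by-field and need schema-aware conversion in Python.
      def needConversionInPython(dt: DataType): Boolean = dt match {
        case DateType | TimestampType => true
        case _: StructType => true
        case _: UserDefinedType => true
        case ArrayType(elementType) => needConversionInPython(elementType)
        case MapType(keyType, valueType) =>
          needConversionInPython(keyType) || needConversionInPython(valueType)
        case _ => false
      }

      // A plain integer column keeps the fast path...
      assert(!needConversionInPython(IntegerType))
      // ...while an array of UDT values (e.g. a column of ML vectors) forces
      // schema-aware conversion, which is the case this patch fixes.
      assert(needConversionInPython(ArrayType(UserDefinedType(IntegerType))))
      println("dispatch sketch OK")
    }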