From a2304e45d31d3d802a3976b9cd1ea0f72a2d604b Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Wed, 14 Sep 2022 14:17:52 +0900
Subject: [PATCH] [SPARK-40414][SQL][PYTHON] More generic type on PythonArrowInput and PythonArrowOutput

### What changes were proposed in this pull request?

This PR proposes to change PythonArrowInput and PythonArrowOutput to be more generic, so that they can cover complex data types on both input and output. This is baseline work for #37863.

### Why are the changes needed?

The traits PythonArrowInput and PythonArrowOutput can be further generalized to cover complex data types on both input and output. For example, not all operators pass a simple InternalRow as the input data to the Python worker, and the same applies to the output data.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #37864 from HeartSaVioR/SPARK-40414.

Authored-by: Jungtaek Lim
Signed-off-by: Jungtaek Lim
---
 .../execution/python/ArrowPythonRunner.scala  |  4 +-
 .../python/CoGroupedArrowPythonRunner.scala   |  2 +-
 .../execution/python/PythonArrowInput.scala   | 47 ++++++++++++++-----
 .../execution/python/PythonArrowOutput.scala  | 22 ++++++---
 4 files changed, 53 insertions(+), 22 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 137e2fe93c790..8467feb91d144 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -34,8 +34,8 @@ class ArrowPythonRunner(
     protected override val timeZoneId: String,
     protected override val workerConf: Map[String, String])
   extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets)
-  with PythonArrowInput
-  with PythonArrowOutput {
+  with BasicPythonArrowInput
+  with BasicPythonArrowOutput {
 
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index e3d8a943d8cf2..2661896ececc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -48,7 +48,7 @@ class CoGroupedArrowPythonRunner(
     conf: Map[String, String])
   extends BasePythonRunner[
     (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs, evalType, argOffsets)
-  with PythonArrowOutput {
+  with BasicPythonArrowOutput {
 
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 79365080f8cb3..6168d0f867adb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -32,15 +32,21 @@ import org.apache.spark.util.Utils
 
 /**
  * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from
- * JVM (an iterator of internal rows) to Python (Arrow).
+ * JVM (an iterator of internal rows + additional data if required) to Python (Arrow).
  */
-private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[InternalRow], _] =>
+private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
   protected val workerConf: Map[String, String]
 
   protected val schema: StructType
 
   protected val timeZoneId: String
 
+  protected def writeIteratorToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator: Iterator[IN]): Unit
+
   protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
     // Write config for the worker as a number of key -> value pairs of strings
     stream.writeInt(workerConf.size)
@@ -53,7 +59,7 @@ private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[Interna
   protected override def newWriterThread(
       env: SparkEnv,
       worker: Socket,
-      inputIterator: Iterator[Iterator[InternalRow]],
+      inputIterator: Iterator[IN],
       partitionIndex: Int,
       context: TaskContext): WriterThread = {
     new WriterThread(env, worker, inputIterator, partitionIndex, context) {
@@ -74,17 +80,8 @@ private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[Interna
           val writer = new ArrowStreamWriter(root, null, dataOut)
           writer.start()
 
-          while (inputIterator.hasNext) {
-            val nextBatch = inputIterator.next()
-
-            while (nextBatch.hasNext) {
-              arrowWriter.write(nextBatch.next())
-            }
+          writeIteratorToArrowStream(root, writer, dataOut, inputIterator)
 
-            arrowWriter.finish()
-            writer.writeBatch()
-            arrowWriter.reset()
-          }
           // end writes footer to the output stream and doesn't clean any resources.
           // It could throw exception if the output stream is closed, so it should be
           // in the try block.
@@ -107,3 +104,27 @@ private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[Interna
     }
   }
 }
+
+private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[InternalRow]] {
+  self: BasePythonRunner[Iterator[InternalRow], _] =>
+
+  protected def writeIteratorToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator: Iterator[Iterator[InternalRow]]): Unit = {
+    val arrowWriter = ArrowWriter.create(root)
+
+    while (inputIterator.hasNext) {
+      val nextBatch = inputIterator.next()
+
+      while (nextBatch.hasNext) {
+        arrowWriter.write(nextBatch.next())
+      }
+
+      arrowWriter.finish()
+      writer.writeBatch()
+      arrowWriter.reset()
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index d06a0d012990b..339f114539c28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -33,12 +33,14 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, Column
 
 /**
  * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from
- * Python (Arrow) to JVM (ColumnarBatch).
+ * Python (Arrow) to JVM (output type being deserialized from ColumnarBatch).
  */
-private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatch] =>
+private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[_, OUT] =>
 
   protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }
 
+  protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT
+
   protected def newReaderIterator(
       stream: DataInputStream,
       writerThread: WriterThread,
@@ -47,7 +49,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
       worker: Socket,
       pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
-      context: TaskContext): Iterator[ColumnarBatch] = {
+      context: TaskContext): Iterator[OUT] = {
     new ReaderIterator(
       stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
 
@@ -74,7 +76,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
         super.handleEndOfDataSection()
       }
 
-      protected override def read(): ColumnarBatch = {
+      protected override def read(): OUT = {
         if (writerThread.exception.isDefined) {
           throw writerThread.exception.get
         }
@@ -84,7 +86,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
             if (batchLoaded) {
               val batch = new ColumnarBatch(vectors)
               batch.setNumRows(root.getRowCount)
-              batch
+              deserializeColumnarBatch(batch, schema)
             } else {
               reader.close(false)
               allocator.close()
@@ -108,7 +110,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
                 throw handlePythonException()
               case SpecialLengths.END_OF_DATA_SECTION =>
                 handleEndOfDataSection()
-                null
+                null.asInstanceOf[OUT]
             }
           }
         } catch handleException
@@ -116,3 +118,11 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
     }
   }
 }
+
+private[python] trait BasicPythonArrowOutput extends PythonArrowOutput[ColumnarBatch] {
+  self: BasePythonRunner[_, ColumnarBatch] =>
+
+  protected def deserializeColumnarBatch(
+      batch: ColumnarBatch,
+      schema: StructType): ColumnarBatch = batch
+}
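
Note (not part of the patch): below is a minimal sketch of how a runner with non-trivial input and output types could plug into the generalized traits. The trait names `RowPythonArrowInput` and `CountingPythonArrowOutput`, the `(ColumnarBatch, Int)` output type, and their behavior are illustrative assumptions; only `PythonArrowInput[IN]`, `PythonArrowOutput[OUT]`, and the abstract methods they declare come from this change.

```scala
package org.apache.spark.sql.execution.python

import java.io.DataOutputStream

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter

import org.apache.spark.api.python.BasePythonRunner
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical input mix-in: the runner's IN type is a plain InternalRow, so the
// partition iterator is written out as a single Arrow batch.
private[python] trait RowPythonArrowInput extends PythonArrowInput[InternalRow] {
  self: BasePythonRunner[InternalRow, _] =>

  protected def writeIteratorToArrowStream(
      root: VectorSchemaRoot,
      writer: ArrowStreamWriter,
      dataOut: DataOutputStream,
      inputIterator: Iterator[InternalRow]): Unit = {
    val arrowWriter = ArrowWriter.create(root)

    // Write every row of the partition into one Arrow batch.
    while (inputIterator.hasNext) {
      arrowWriter.write(inputIterator.next())
    }
    arrowWriter.finish()
    writer.writeBatch()
    arrowWriter.reset()
  }
}

// Hypothetical output mix-in: each Arrow batch read back from the Python worker is
// paired with its row count instead of being returned as a bare ColumnarBatch.
private[python] trait CountingPythonArrowOutput
  extends PythonArrowOutput[(ColumnarBatch, Int)] {
  self: BasePythonRunner[_, (ColumnarBatch, Int)] =>

  protected def deserializeColumnarBatch(
      batch: ColumnarBatch,
      schema: StructType): (ColumnarBatch, Int) = (batch, batch.numRows())
}
```

The `BasicPythonArrowInput` and `BasicPythonArrowOutput` mix-ins added by the patch keep the previous behavior for existing runners such as `ArrowPythonRunner`, so only runners that need a different IN/OUT shape have to implement the new abstract methods.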