diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 137e2fe93c790..8467feb91d144 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -34,8 +34,8 @@ class ArrowPythonRunner(
     protected override val timeZoneId: String,
     protected override val workerConf: Map[String, String])
   extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets)
-  with PythonArrowInput
-  with PythonArrowOutput {
+  with BasicPythonArrowInput
+  with BasicPythonArrowOutput {
 
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index e3d8a943d8cf2..2661896ececc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -48,7 +48,7 @@ class CoGroupedArrowPythonRunner(
     conf: Map[String, String])
   extends BasePythonRunner[
     (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs, evalType, argOffsets)
-  with PythonArrowOutput {
+  with BasicPythonArrowOutput {
 
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 79365080f8cb3..6168d0f867adb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -32,15 +32,21 @@ import org.apache.spark.util.Utils
 
 /**
  * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from
- * JVM (an iterator of internal rows) to Python (Arrow).
+ * JVM (an iterator of internal rows + additional data if required) to Python (Arrow).
  */
-private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[InternalRow], _] =>
+private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
   protected val workerConf: Map[String, String]
 
   protected val schema: StructType
 
   protected val timeZoneId: String
 
+  protected def writeIteratorToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator: Iterator[IN]): Unit
+
   protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
     // Write config for the worker as a number of key -> value pairs of strings
     stream.writeInt(workerConf.size)
@@ -53,7 +59,7 @@ private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[Interna
   protected override def newWriterThread(
       env: SparkEnv,
       worker: Socket,
-      inputIterator: Iterator[Iterator[InternalRow]],
+      inputIterator: Iterator[IN],
       partitionIndex: Int,
       context: TaskContext): WriterThread = {
     new WriterThread(env, worker, inputIterator, partitionIndex, context) {
@@ -74,17 +80,8 @@ private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[Interna
           val writer = new ArrowStreamWriter(root, null, dataOut)
           writer.start()
 
-          while (inputIterator.hasNext) {
-            val nextBatch = inputIterator.next()
-
-            while (nextBatch.hasNext) {
-              arrowWriter.write(nextBatch.next())
-            }
+          writeIteratorToArrowStream(root, writer, dataOut, inputIterator)
 
-            arrowWriter.finish()
-            writer.writeBatch()
-            arrowWriter.reset()
-          }
           // end writes footer to the output stream and doesn't clean any resources.
           // It could throw exception if the output stream is closed, so it should be
           // in the try block.
@@ -107,3 +104,27 @@ private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[Interna
     }
   }
 }
+
+private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[InternalRow]] {
+  self: BasePythonRunner[Iterator[InternalRow], _] =>
+
+  protected def writeIteratorToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator: Iterator[Iterator[InternalRow]]): Unit = {
+    val arrowWriter = ArrowWriter.create(root)
+
+    while (inputIterator.hasNext) {
+      val nextBatch = inputIterator.next()
+
+      while (nextBatch.hasNext) {
+        arrowWriter.write(nextBatch.next())
+      }
+
+      arrowWriter.finish()
+      writer.writeBatch()
+      arrowWriter.reset()
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index d06a0d012990b..339f114539c28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -33,12 +33,14 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, Column
 
 /**
  * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from
- * Python (Arrow) to JVM (ColumnarBatch).
+ * Python (Arrow) to JVM (output type being deserialized from ColumnarBatch).
  */
-private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatch] =>
+private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[_, OUT] =>
 
   protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }
 
+  protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT
+
   protected def newReaderIterator(
       stream: DataInputStream,
       writerThread: WriterThread,
@@ -47,7 +49,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
       worker: Socket,
       pid: Option[Int],
       releasedOrClosed: AtomicBoolean,
-      context: TaskContext): Iterator[ColumnarBatch] = {
+      context: TaskContext): Iterator[OUT] = {
     new ReaderIterator(
       stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
 
@@ -74,7 +76,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
         super.handleEndOfDataSection()
       }
 
-      protected override def read(): ColumnarBatch = {
+      protected override def read(): OUT = {
         if (writerThread.exception.isDefined) {
           throw writerThread.exception.get
         }
@@ -84,7 +86,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
             if (batchLoaded) {
               val batch = new ColumnarBatch(vectors)
               batch.setNumRows(root.getRowCount)
-              batch
+              deserializeColumnarBatch(batch, schema)
             } else {
               reader.close(false)
               allocator.close()
@@ -108,7 +110,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
                 throw handlePythonException()
               case SpecialLengths.END_OF_DATA_SECTION =>
                 handleEndOfDataSection()
-                null
+                null.asInstanceOf[OUT]
             }
           }
         } catch handleException
@@ -116,3 +118,11 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
     }
   }
 }
+
+private[python] trait BasicPythonArrowOutput extends PythonArrowOutput[ColumnarBatch] {
+  self: BasePythonRunner[_, ColumnarBatch] =>
+
+  protected def deserializeColumnarBatch(
+      batch: ColumnarBatch,
+      schema: StructType): ColumnarBatch = batch
+}
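For illustration only (not part of the diff above): the change generalizes PythonArrowInput/PythonArrowOutput over the runner's input and output types, while BasicPythonArrowInput and BasicPythonArrowOutput keep the previous Iterator[InternalRow] -> ColumnarBatch behavior. Below is a minimal sketch of how a runner could mix the new traits in with a custom output type. CountedBatch and CountingArrowPythonRunner are hypothetical names invented for this example; the constructor shape and the simplifiedTraceback override mirror ArrowPythonRunner from the diff, and the sketch assumes it lives in the same Spark-internal package.

package org.apache.spark.sql.execution.python

import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical output type: each Arrow batch returned by the Python worker,
// tagged with its row count.
case class CountedBatch(numRows: Int, batch: ColumnarBatch)

// Hypothetical runner: reuses the default JVM -> Python Arrow writer
// (BasicPythonArrowInput) but turns each ColumnarBatch read back from Python
// into a CountedBatch instead of passing it through unchanged.
class CountingArrowPythonRunner(
    funcs: Seq[ChainedPythonFunctions],
    evalType: Int,
    argOffsets: Array[Array[Int]],
    protected override val schema: StructType,
    protected override val timeZoneId: String,
    protected override val workerConf: Map[String, String])
  extends BasePythonRunner[Iterator[InternalRow], CountedBatch](funcs, evalType, argOffsets)
  with BasicPythonArrowInput
  with PythonArrowOutput[CountedBatch] {

  override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

  // Called once per Arrow batch deserialized from the Python worker's output stream.
  protected def deserializeColumnarBatch(
      batch: ColumnarBatch,
      schema: StructType): CountedBatch = CountedBatch(batch.numRows(), batch)
}

The split keeps existing runners (ArrowPythonRunner, CoGroupedArrowPythonRunner) unchanged apart from the Basic* mix-ins, while presumably allowing future runners to supply their own writeIteratorToArrowStream and deserializeColumnarBatch when the data exchanged with Python is more than a plain stream of rows or batches.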