Skip to content

Commit

Permalink
[SPARK-40414][SQL][PYTHON] More generic type on PythonArrowInput and …
Browse files Browse the repository at this point in the history
…PythonArrowOutput

### What changes were proposed in this pull request?

This PR proposes to change PythonArrowInput and PythonArrowOutput to be more generic to cover the complex data type on both input and output. This is a baseline work for #37863.

### Why are the changes needed?

The traits PythonArrowInput and PythonArrowOutput can be further generalized to cover complex data type on both input and output. E.g. Not all operators would have simple InternalRow as input data to pass to Python worker and vice versa for output data.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #37864 from HeartSaVioR/SPARK-40414.

Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
  • Loading branch information
HeartSaVioR committed Sep 14, 2022
1 parent d45b894 commit a2304e4
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ class ArrowPythonRunner(
protected override val timeZoneId: String,
protected override val workerConf: Map[String, String])
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets)
with PythonArrowInput
with PythonArrowOutput {
with BasicPythonArrowInput
with BasicPythonArrowOutput {

override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class CoGroupedArrowPythonRunner(
conf: Map[String, String])
extends BasePythonRunner[
(Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs, evalType, argOffsets)
with PythonArrowOutput {
with BasicPythonArrowOutput {

override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,21 @@ import org.apache.spark.util.Utils

/**
* A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from
* JVM (an iterator of internal rows) to Python (Arrow).
* JVM (an iterator of internal rows + additional data if required) to Python (Arrow).
*/
private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[InternalRow], _] =>
private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
protected val workerConf: Map[String, String]

protected val schema: StructType

protected val timeZoneId: String

protected def writeIteratorToArrowStream(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
dataOut: DataOutputStream,
inputIterator: Iterator[IN]): Unit

protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
// Write config for the worker as a number of key -> value pairs of strings
stream.writeInt(workerConf.size)
Expand All @@ -53,7 +59,7 @@ private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[Interna
protected override def newWriterThread(
env: SparkEnv,
worker: Socket,
inputIterator: Iterator[Iterator[InternalRow]],
inputIterator: Iterator[IN],
partitionIndex: Int,
context: TaskContext): WriterThread = {
new WriterThread(env, worker, inputIterator, partitionIndex, context) {
Expand All @@ -74,17 +80,8 @@ private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[Interna
val writer = new ArrowStreamWriter(root, null, dataOut)
writer.start()

while (inputIterator.hasNext) {
val nextBatch = inputIterator.next()

while (nextBatch.hasNext) {
arrowWriter.write(nextBatch.next())
}
writeIteratorToArrowStream(root, writer, dataOut, inputIterator)

arrowWriter.finish()
writer.writeBatch()
arrowWriter.reset()
}
// end writes footer to the output stream and doesn't clean any resources.
// It could throw exception if the output stream is closed, so it should be
// in the try block.
Expand All @@ -107,3 +104,27 @@ private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[Interna
}
}
}

private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[InternalRow]] {
self: BasePythonRunner[Iterator[InternalRow], _] =>

protected def writeIteratorToArrowStream(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
dataOut: DataOutputStream,
inputIterator: Iterator[Iterator[InternalRow]]): Unit = {
val arrowWriter = ArrowWriter.create(root)

while (inputIterator.hasNext) {
val nextBatch = inputIterator.next()

while (nextBatch.hasNext) {
arrowWriter.write(nextBatch.next())
}

arrowWriter.finish()
writer.writeBatch()
arrowWriter.reset()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, Column

/**
* A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from
* Python (Arrow) to JVM (ColumnarBatch).
* Python (Arrow) to JVM (output type being deserialized from ColumnarBatch).
*/
private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatch] =>
private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[_, OUT] =>

protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }

protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT

protected def newReaderIterator(
stream: DataInputStream,
writerThread: WriterThread,
Expand All @@ -47,7 +49,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
worker: Socket,
pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[ColumnarBatch] = {
context: TaskContext): Iterator[OUT] = {

new ReaderIterator(
stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
Expand All @@ -74,7 +76,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
super.handleEndOfDataSection()
}

protected override def read(): ColumnarBatch = {
protected override def read(): OUT = {
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
}
Expand All @@ -84,7 +86,7 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
if (batchLoaded) {
val batch = new ColumnarBatch(vectors)
batch.setNumRows(root.getRowCount)
batch
deserializeColumnarBatch(batch, schema)
} else {
reader.close(false)
allocator.close()
Expand All @@ -108,11 +110,19 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
throw handlePythonException()
case SpecialLengths.END_OF_DATA_SECTION =>
handleEndOfDataSection()
null
null.asInstanceOf[OUT]
}
}
} catch handleException
}
}
}
}

private[python] trait BasicPythonArrowOutput extends PythonArrowOutput[ColumnarBatch] {
self: BasePythonRunner[_, ColumnarBatch] =>

protected def deserializeColumnarBatch(
batch: ColumnarBatch,
schema: StructType): ColumnarBatch = batch
}

0 comments on commit a2304e4

Please sign in to comment.