Make state spillable in partitioned writer [databricks] (#8667)
* Make state spillable in partitioned writer

Signed-off-by: Alessandro Bellina <[email protected]>

---------

Signed-off-by: Alessandro Bellina <[email protected]>
abellina authored Jul 26, 2023
1 parent b8a07fe commit 5421a85
Showing 22 changed files with 1,032 additions and 357 deletions.
@@ -20,9 +20,11 @@ import java.io.OutputStream

import scala.collection.mutable

import ai.rapids.cudf.{HostBufferConsumer, HostMemoryBuffer, NvtxColor, NvtxRange, Table, TableWriter}
import com.nvidia.spark.rapids.Arm.withResource
import ai.rapids.cudf.{HostBufferConsumer, HostMemoryBuffer, NvtxColor, NvtxRange, TableWriter}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitSpillableInHalfByRows, withRestoreOnRetry, withRetry, withRetryNoSplit}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FSDataOutputStream, Path}
import org.apache.hadoop.mapreduce.TaskAttemptContext

@@ -38,6 +40,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
* `org.apache.spark.sql.execution.datasources.OutputWriterFactory`.
*/
abstract class ColumnarOutputWriterFactory extends Serializable {
/** Returns the default partition flush size in bytes, format specific */
def partitionFlushSize(context: TaskAttemptContext): Long = 128L * 1024L * 1024L // 128M

/** Returns the file extension to be used when writing files out. */
def getFileExtension(context: TaskAttemptContext): String
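
The new `partitionFlushSize` hook lets each file format tell the partitioned writer how much data to accumulate before flushing (128 MiB unless overridden; the Parquet override appears further down in this diff). Below is a minimal sketch of how a caller might consume it. `PartitionBuffer`, `estimatedBytes`, and the `flush` callback are hypothetical stand-ins, not the plugin's actual partitioned-writer code:

```scala
import org.apache.hadoop.mapreduce.TaskAttemptContext

// Hypothetical consumer of partitionFlushSize: track an estimate of the bytes
// buffered for one partition and flush once it crosses the format-specific
// threshold. This is a sketch, not the plugin's partitioned writer.
class PartitionBuffer(factory: ColumnarOutputWriterFactory,
                      context: TaskAttemptContext) {
  private val flushThreshold: Long = factory.partitionFlushSize(context)
  private var estimatedBytes: Long = 0L

  def append(batchBytes: Long)(flush: () => Unit): Unit = {
    estimatedBytes += batchBytes
    if (estimatedBytes >= flushThreshold) {
      flush()             // write out whatever has accumulated for this partition
      estimatedBytes = 0L
    }
  }
}
```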
@@ -67,14 +71,20 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext,
rangeName: String,
includeRetry: Boolean) extends HostBufferConsumer {

val tableWriter: TableWriter
val conf = context.getConfiguration
protected val tableWriter: TableWriter

private[this] val outputStream: FSDataOutputStream = {
protected val conf: Configuration = context.getConfiguration

// This is implemented as a method to make it easier to subclass
// ColumnarOutputWriter in the tests, and override this behavior.
protected def getOutputStream: FSDataOutputStream = {
val hadoopPath = new Path(path)
val fs = hadoopPath.getFileSystem(conf)
fs.create(hadoopPath, false)
}

protected val outputStream: FSDataOutputStream = getOutputStream

private[this] val tempBuffer = new Array[Byte](128 * 1024)
private[this] var anythingWritten = false
private[this] val buffers = mutable.Queue[(HostMemoryBuffer, Long)]()
@@ -93,146 +103,103 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext,
true
}

/**
* Persists a columnar batch. Invoked on the executor side. When writing to dynamically
* partitioned tables, dynamic partition columns are not included in columns to be written.
*
* NOTE: This method will close `batch`. We do this because we want
* to free GPU memory after the GPU has finished encoding the data but before
* it is written to the distributed filesystem. The GPU semaphore is released
* during the distributed filesystem transfer to allow other tasks to start/continue
* GPU processing.
*/
def writeAndClose(
batch: ColumnarBatch,
private[this] def updateStatistics(
writeStartTime: Long,
gpuTime: Long,
statsTrackers: Seq[ColumnarWriteTaskStatsTracker]): Unit = {
var needToCloseBatch = true
try {
val writeStartTimestamp = System.nanoTime
val writeRange = new NvtxRange("File write", NvtxColor.YELLOW)
val gpuTime = try {
needToCloseBatch = false
writeBatch(batch)
} finally {
writeRange.close()
}

// Update statistics
val writeTime = System.nanoTime - writeStartTimestamp - gpuTime
statsTrackers.foreach {
case gpuTracker: GpuWriteTaskStatsTracker =>
gpuTracker.addWriteTime(writeTime)
gpuTracker.addGpuTime(gpuTime)
case _ =>
}
} finally {
if (needToCloseBatch) {
batch.close()
}
// Update statistics
val writeTime = System.nanoTime - writeStartTime - gpuTime
statsTrackers.foreach {
case gpuTracker: GpuWriteTaskStatsTracker =>
gpuTracker.addWriteTime(writeTime)
gpuTracker.addGpuTime(gpuTime)
case _ =>
}
}

protected def scanTableBeforeWrite(table: Table): Unit = {
protected def throwIfRebaseNeededInExceptionMode(batch: ColumnarBatch): Unit = {
// NOOP for now, but allows a child to override this
}


/**
* Writes the columnar batch and returns the time in ns taken to write
* Persists a columnar batch. Invoked on the executor side. When writing to dynamically
* partitioned tables, dynamic partition columns are not included in columns to be written.
*
* NOTE: This method will close `batch`. We do this because we want
* NOTE: This method will close `spillableBatch`. We do this because we want
* to free GPU memory after the GPU has finished encoding the data but before
* it is written to the distributed filesystem. The GPU semaphore is released
* during the distributed filesystem transfer to allow other tasks to start/continue
* GPU processing.
*
* @param batch Columnar batch that needs to be written
* @return time in ns taken to write the batch
*/
private[this] def writeBatch(batch: ColumnarBatch): Long = {
if (includeRetry) {
writeBatchWithRetry(batch)
} else {
writeBatchNoRetry(batch)
}
}

/** Apply any necessary casts before writing batch out */
def transform(cb: ColumnarBatch): Option[ColumnarBatch] = None

private[this] def writeBatchWithRetry(batch: ColumnarBatch): Long = {
val sb = SpillableColumnarBatch(batch, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
RmmRapidsRetryIterator.withRetry(sb, RmmRapidsRetryIterator.splitSpillableInHalfByRows) { sb =>
val cr = new CheckpointRestore {
override def checkpoint(): Unit = ()
override def restore(): Unit = dropBufferedData()
def writeSpillableAndClose(
spillableBatch: SpillableColumnarBatch,
statsTrackers: Seq[ColumnarWriteTaskStatsTracker]): Long = {
val writeStartTime = System.nanoTime
closeOnExcept(spillableBatch) { _ =>
val cb = withRetryNoSplit[ColumnarBatch] {
spillableBatch.getColumnarBatch()
}
val startTimestamp = System.nanoTime
withResource(sb.getColumnarBatch()) { cb =>
//TODO: we should really apply the transformations to cast timestamps
// to the expected types before spilling but we need a SpillableTable
// rather than a SpillableColumnBatch to be able to do that
// See https://github.com/NVIDIA/spark-rapids/issues/8262
RmmRapidsRetryIterator.withRestoreOnRetry(cr) {
withResource(new NvtxRange(s"GPU $rangeName write", NvtxColor.BLUE)) { _ =>
scan(cb)
transform(cb) match {
case Some(transformed) =>
// because we created a new transformed batch, we need to make sure we close it
withResource(transformed) { _ =>
write(transformed)
}
case _ =>
write(cb)
}
}
}
// run pre-flight checks and update stats
withResource(cb) { _ =>
throwIfRebaseNeededInExceptionMode(cb)
// NOTE: it is imperative that `newBatch` is not in a retry block.
// Otherwise it WILL corrupt writers that generate metadata in this method (like delta)
statsTrackers.foreach(_.newBatch(path(), cb))
}
GpuSemaphore.releaseIfNecessary(TaskContext.get)
val gpuTime = System.nanoTime - startTimestamp
writeBufferedData()
gpuTime
}.sum
}

private[this] def writeBatchNoRetry(batch: ColumnarBatch): Long = {
var needToCloseBatch = true
try {
val startTimestamp = System.nanoTime
withResource(new NvtxRange(s"GPU $rangeName write", NvtxColor.BLUE)) { _ =>
scan(batch)
transform(batch) match {
case Some(transformed) =>
// because we created a new transformed batch, we need to make sure we close it
withResource(transformed) { _ =>
write(transformed)
}
case _ =>
write(batch)
}
val gpuTime = if (includeRetry) {
//TODO: we should really apply the transformations to cast timestamps
// to the expected types before spilling but we need a SpillableTable
// rather than a SpillableColumnBatch to be able to do that
// See https://github.com/NVIDIA/spark-rapids/issues/8262
withRetry(spillableBatch, splitSpillableInHalfByRows) { attempt =>
withRestoreOnRetry(checkpointRestore) {
bufferBatchAndClose(attempt.getColumnarBatch())
}
}.sum
} else {
withResource(spillableBatch) { _ =>
bufferBatchAndClose(spillableBatch.getColumnarBatch())
}
}
// we successfully buffered to host memory, release the semaphore and write
// the buffered data to the FS
GpuSemaphore.releaseIfNecessary(TaskContext.get)
writeBufferedData()
updateStatistics(writeStartTime, gpuTime, statsTrackers)
spillableBatch.numRows()
}
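
With `writeAndClose`/`writeBatch` replaced by `writeSpillableAndClose`, callers now hand the writer a `SpillableColumnarBatch` instead of a raw `ColumnarBatch`, so the data can be spilled while it waits to be written. A hedged sketch of a call site follows; it assumes the plugin's writer and stats-tracker types are already in scope, and only `SpillableColumnarBatch`, `SpillPriorities.ACTIVE_ON_DECK_PRIORITY`, and `writeSpillableAndClose` come from the code in this diff:

```scala
import com.nvidia.spark.rapids.{SpillableColumnarBatch, SpillPriorities}
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical call site: wrap the incoming batch so it is spillable while it
// waits its turn, then let the writer own it. writeSpillableAndClose closes
// the spillable batch and returns the number of rows written.
def writeOneBatch(
    writer: ColumnarOutputWriter,
    batch: ColumnarBatch,
    statsTrackers: Seq[ColumnarWriteTaskStatsTracker]): Long = {
  val spillable =
    SpillableColumnarBatch(batch, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
  writer.writeSpillableAndClose(spillable, statsTrackers)
}
```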

// Batch is no longer needed, write process from here does not use GPU.
batch.close()
needToCloseBatch = false
GpuSemaphore.releaseIfNecessary(TaskContext.get)
val gpuTime = System.nanoTime - startTimestamp
writeBufferedData()
gpuTime
} finally {
if (needToCloseBatch) {
batch.close()
// protected for testing
protected[this] def bufferBatchAndClose(batch: ColumnarBatch): Long = {
val startTimestamp = System.nanoTime
withResource(new NvtxRange(s"GPU $rangeName write", NvtxColor.BLUE)) { _ =>
withResource(transformAndClose(batch)) { maybeTransformed =>
encodeAndBufferToHost(maybeTransformed)
}
}
// time spent on GPU encoding to the host sink
System.nanoTime - startTimestamp
}

private def scan(batch: ColumnarBatch): Unit = {
withResource(GpuColumnVector.from(batch)) { table =>
scanTableBeforeWrite(table)
}
/** Apply any necessary casts before writing batch out */
def transformAndClose(cb: ColumnarBatch): ColumnarBatch = cb

private val checkpointRestore = new CheckpointRestore {
override def checkpoint(): Unit = ()
override def restore(): Unit = dropBufferedData()
}
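
The `CheckpointRestore` above is the hook the retry framework uses: when an attempt hits a retryable OOM, `restore()` is invoked before the attempt is re-run, and the only state that needs rolling back here is the host-side buffered data. Below is a schematic of the pattern built from the helpers imported at the top of this file; `encodeToHost` is a hypothetical stand-in for `bufferBatchAndClose` and is expected to close the batch it is given:

```scala
// Schematic retry loop: the framework may split the spillable input in half by
// rows on OOM and calls cr.restore() before re-running a failed attempt.
// Ownership of `spillable` passes to the retry helper (note the non-retry
// branch above closes it explicitly instead). One value is produced per
// (possibly split) attempt, so the per-attempt encode times are summed.
def writeWithRetry(
    spillable: SpillableColumnarBatch,
    cr: CheckpointRestore,
    encodeToHost: ColumnarBatch => Long): Long = {
  withRetry(spillable, splitSpillableInHalfByRows) { attempt =>
    withRestoreOnRetry(cr) {
      encodeToHost(attempt.getColumnarBatch())
    }
  }.sum
}
```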

private def write(batch: ColumnarBatch): Unit = {
private def encodeAndBufferToHost(batch: ColumnarBatch): Unit = {
withResource(GpuColumnVector.from(batch)) { table =>
// `anythingWritten` is set here as an indication that there was data at all
// to write, even if the `tableWriter.write` method fails. If we fail to write
// and the task fails, any output is going to be discarded anyway, so no data
// corruption to worry about. Otherwise, we should retry (OOM case).
// If we have nothing to write, we won't flip this flag to true and we will
// buffer an empty batch on close() to work around issues in cuDF
// where corrupt files can be written if nothing is encoded via the writer.
anythingWritten = true
tableWriter.write(table)
}
@@ -245,9 +212,10 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext,
def close(): Unit = {
if (!anythingWritten) {
// This prevents writing out bad files
writeBatch(GpuColumnVector.emptyBatch(dataSchema))
bufferBatchAndClose(GpuColumnVector.emptyBatch(dataSchema))
}
tableWriter.close()
GpuSemaphore.releaseIfNecessary(TaskContext.get())
writeBufferedData()
outputStream.close()
}
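
Taken together, the new write path in this file is: encode the batch on the GPU into host memory (cuDF hands the encoded bytes back through the `HostBufferConsumer` callback this class implements), release the GPU semaphore, and only then stream the buffered bytes to the distributed filesystem. A simplified sketch of that buffer-then-flush pattern is below; it assumes the cuDF `handleBuffer(HostMemoryBuffer, long)` callback and `HostMemoryBuffer.getBytes`, and it is not the writer's actual implementation:

```scala
import java.io.OutputStream

import scala.collection.mutable

import ai.rapids.cudf.{HostBufferConsumer, HostMemoryBuffer}

// Sketch of the buffer-then-flush pattern: cuDF calls handleBuffer(...) while
// encoding on the GPU, the bytes sit in host memory, and flushToStream() is
// only called later, after the GPU semaphore has been released. The real
// writer also drops these buffers on retry; this is not the actual code.
class BufferedHostSink(out: OutputStream) extends HostBufferConsumer {
  private val pending = mutable.Queue[(HostMemoryBuffer, Long)]()
  private val tmp = new Array[Byte](128 * 1024)

  override def handleBuffer(buffer: HostMemoryBuffer, len: Long): Unit =
    pending += ((buffer, len))

  def flushToStream(): Unit = while (pending.nonEmpty) {
    val (buf, len) = pending.dequeue()
    try {
      var offset = 0L
      while (offset < len) {
        val toCopy = math.min(tmp.length.toLong, len - offset).toInt
        buf.getBytes(tmp, 0, offset, toCopy) // copy out of host memory
        out.write(tmp, 0, toCopy)
        offset += toCopy
      }
    } finally {
      buf.close()
    }
  }
}
```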
@@ -290,6 +290,10 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging {
override def getFileExtension(context: TaskAttemptContext): String = {
CodecConfig.from(context).getCodec.getExtension + ".parquet"
}

override def partitionFlushSize(context: TaskAttemptContext): Long =
context.getConfiguration.getLong("write.parquet.row-group-size-bytes",
128L * 1024L * 1024L) // 128M
}
}
}
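
The Parquet factory now derives its partition flush size from `write.parquet.row-group-size-bytes` in the task's Hadoop configuration, falling back to 128 MiB. The snippet below only mirrors that lookup to show how the knob resolves; it is not the plugin's factory itself:

```scala
import org.apache.hadoop.conf.Configuration

// Resolve the flush size the same way the factory above does: read the
// row-group-size key from the Hadoop configuration, default to 128 MiB.
val hadoopConf = new Configuration()
hadoopConf.setLong("write.parquet.row-group-size-bytes", 64L * 1024 * 1024) // 64 MiB

val flushSize = hadoopConf.getLong(
  "write.parquet.row-group-size-bytes",
  128L * 1024 * 1024) // same default as the factory

assert(flushSize == 64L * 1024 * 1024)
```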
@@ -306,9 +310,9 @@ class GpuParquetWriter(

val outputTimestampType = conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key)

override def scanTableBeforeWrite(table: Table): Unit = {
(0 until table.getNumberOfColumns).foreach { i =>
val col = table.getColumn(i)
override def throwIfRebaseNeededInExceptionMode(batch: ColumnarBatch): Unit = {
val cols = GpuColumnVector.extractBases(batch)
cols.foreach { col =>
// if col is a day
if (dateRebaseException && RebaseHelper.isDateRebaseNeededInWrite(col)) {
throw DataSourceUtils.newRebaseExceptionInWrite("Parquet")
@@ -320,12 +324,14 @@ class GpuParquetWriter(
}
}

override def transform(batch: ColumnarBatch): Option[ColumnarBatch] = {
val transformedCols = GpuColumnVector.extractColumns(batch).safeMap { cv =>
new GpuColumnVector(cv.dataType, deepTransformColumn(cv.getBase, cv.dataType))
.asInstanceOf[org.apache.spark.sql.vectorized.ColumnVector]
override def transformAndClose(batch: ColumnarBatch): ColumnarBatch = {
withResource(batch) { _ =>
val transformedCols = GpuColumnVector.extractColumns(batch).safeMap { cv =>
new GpuColumnVector(cv.dataType, deepTransformColumn(cv.getBase, cv.dataType))
.asInstanceOf[org.apache.spark.sql.vectorized.ColumnVector]
}
new ColumnarBatch(transformedCols)
}
Some(new ColumnarBatch(transformedCols))
}

private def deepTransformColumn(cv: ColumnVector, dt: DataType): ColumnVector = {
@@ -786,6 +786,13 @@ object RapidsBufferCatalog extends Logging {
closeImpl()
}

/**
* Only used in unit tests, it returns the number of buffers in the catalog.
*/
def numBuffers: Int = {
_singleton.numBuffers
}
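
`numBuffers` is exposed purely so unit tests can check the catalog's state, for example that a test leaves no spillable buffers behind once everything has been closed. A hypothetical assertion along those lines (the surrounding test harness is assumed, not part of this commit):

```scala
// Hypothetical test check: after the code under test has closed all of its
// spillable batches, the catalog should be empty again.
assert(RapidsBufferCatalog.numBuffers == 0,
  s"expected an empty catalog, found ${RapidsBufferCatalog.numBuffers} buffers")
```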

private def closeImpl(): Unit = synchronized {
if (_singleton != null) {
_singleton.close()
(The remaining changed files in this commit are not shown here.)