From 5421a85e0fa34d99e4fa3d2038956ff468e12b78 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Wed, 26 Jul 2023 18:23:47 -0500 Subject: [PATCH] Make state spillable in partitioned writer [databricks] (#8667) * Make state spillable in partitioned writer Signed-off-by: Alessandro Bellina --------- Signed-off-by: Alessandro Bellina --- .../spark/rapids/ColumnarOutputWriter.scala | 208 +++---- .../spark/rapids/GpuParquetFileFormat.scala | 22 +- .../spark/rapids/RapidsBufferCatalog.scala | 7 + .../hive/rapids/GpuHiveTextFileFormat.scala | 60 +- .../BasicColumnarWriteStatsTracker.scala | 4 +- .../rapids/ColumnarWriteStatsTracker.scala | 5 +- .../sql/rapids/GpuFileFormatDataWriter.scala | 326 +++++----- .../spark/sql/rapids/GpuOrcFileFormat.scala | 4 + .../sql/rapids/GpuFileFormatWriter.scala | 2 +- .../shims/SparkUpgradeExceptionShims.scala | 5 + .../shims/SparkUpgradeExceptionShims.scala | 6 + .../shims/SparkUpgradeExceptionShims.scala | 6 + .../sql/rapids/GpuFileFormatWriter.scala | 2 +- .../shims/SparkUpgradeExceptionShims.scala | 6 + .../com/nvidia/spark/rapids/CastOpSuite.scala | 8 +- .../DeviceMemoryEventHandlerSuite.scala | 23 +- .../spark/rapids/ParquetWriterSuite.scala | 56 +- .../spark/rapids/RmmSparkRetrySuiteBase.scala | 5 + .../rapids/SparkQueryCompareTestSuite.scala | 28 +- .../shuffle/RapidsShuffleTestHelper.scala | 12 +- .../rapids/GpuFileFormatDataWriterSuite.scala | 578 ++++++++++++++++++ .../filecache/FileCacheIntegrationSuite.scala | 16 +- 22 files changed, 1032 insertions(+), 357 deletions(-) create mode 100644 tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala index 6d8f078ec81..278b19bb661 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala @@ -20,9 +20,11 @@ import java.io.OutputStream import scala.collection.mutable -import ai.rapids.cudf.{HostBufferConsumer, HostMemoryBuffer, NvtxColor, NvtxRange, Table, TableWriter} -import com.nvidia.spark.rapids.Arm.withResource +import ai.rapids.cudf.{HostBufferConsumer, HostMemoryBuffer, NvtxColor, NvtxRange, TableWriter} +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitSpillableInHalfByRows, withRestoreOnRetry, withRetry, withRetryNoSplit} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FSDataOutputStream, Path} import org.apache.hadoop.mapreduce.TaskAttemptContext @@ -38,6 +40,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch * `org.apache.spark.sql.execution.datasources.OutputWriterFactory`. */ abstract class ColumnarOutputWriterFactory extends Serializable { + /** Returns the default partition flush size in bytes, format specific */ + def partitionFlushSize(context: TaskAttemptContext): Long = 128L * 1024L * 1024L // 128M /** Returns the file extension to be used when writing files out. 
*/ def getFileExtension(context: TaskAttemptContext): String @@ -67,14 +71,20 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext, rangeName: String, includeRetry: Boolean) extends HostBufferConsumer { - val tableWriter: TableWriter - val conf = context.getConfiguration + protected val tableWriter: TableWriter - private[this] val outputStream: FSDataOutputStream = { + protected val conf: Configuration = context.getConfiguration + + // This is implemented as a method to make it easier to subclass + // ColumnarOutputWriter in the tests, and override this behavior. + protected def getOutputStream: FSDataOutputStream = { val hadoopPath = new Path(path) val fs = hadoopPath.getFileSystem(conf) fs.create(hadoopPath, false) } + + protected val outputStream: FSDataOutputStream = getOutputStream + private[this] val tempBuffer = new Array[Byte](128 * 1024) private[this] var anythingWritten = false private[this] val buffers = mutable.Queue[(HostMemoryBuffer, Long)]() @@ -93,146 +103,103 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext, true } - /** - * Persists a columnar batch. Invoked on the executor side. When writing to dynamically - * partitioned tables, dynamic partition columns are not included in columns to be written. - * - * NOTE: This method will close `batch`. We do this because we want - * to free GPU memory after the GPU has finished encoding the data but before - * it is written to the distributed filesystem. The GPU semaphore is released - * during the distributed filesystem transfer to allow other tasks to start/continue - * GPU processing. - */ - def writeAndClose( - batch: ColumnarBatch, + private[this] def updateStatistics( + writeStartTime: Long, + gpuTime: Long, statsTrackers: Seq[ColumnarWriteTaskStatsTracker]): Unit = { - var needToCloseBatch = true - try { - val writeStartTimestamp = System.nanoTime - val writeRange = new NvtxRange("File write", NvtxColor.YELLOW) - val gpuTime = try { - needToCloseBatch = false - writeBatch(batch) - } finally { - writeRange.close() - } - - // Update statistics - val writeTime = System.nanoTime - writeStartTimestamp - gpuTime - statsTrackers.foreach { - case gpuTracker: GpuWriteTaskStatsTracker => - gpuTracker.addWriteTime(writeTime) - gpuTracker.addGpuTime(gpuTime) - case _ => - } - } finally { - if (needToCloseBatch) { - batch.close() - } + // Update statistics + val writeTime = System.nanoTime - writeStartTime - gpuTime + statsTrackers.foreach { + case gpuTracker: GpuWriteTaskStatsTracker => + gpuTracker.addWriteTime(writeTime) + gpuTracker.addGpuTime(gpuTime) + case _ => } } - protected def scanTableBeforeWrite(table: Table): Unit = { + protected def throwIfRebaseNeededInExceptionMode(batch: ColumnarBatch): Unit = { // NOOP for now, but allows a child to override this } + /** - * Writes the columnar batch and returns the time in ns taken to write + * Persists a columnar batch. Invoked on the executor side. When writing to dynamically + * partitioned tables, dynamic partition columns are not included in columns to be written. * - * NOTE: This method will close `batch`. We do this because we want + * NOTE: This method will close `spillableBatch`. We do this because we want * to free GPU memory after the GPU has finished encoding the data but before * it is written to the distributed filesystem. The GPU semaphore is released * during the distributed filesystem transfer to allow other tasks to start/continue * GPU processing. 
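+   * @param spillableBatch spillable batch to encode and write; always closed by this method
+   * @param statsTrackers trackers to update with the write and GPU time for this batch
+   * @return the number of rows written (the row count of `spillableBatch`)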
- * - * @param batch Columnar batch that needs to be written - * @return time in ns taken to write the batch */ - private[this] def writeBatch(batch: ColumnarBatch): Long = { - if (includeRetry) { - writeBatchWithRetry(batch) - } else { - writeBatchNoRetry(batch) - } - } - - /** Apply any necessary casts before writing batch out */ - def transform(cb: ColumnarBatch): Option[ColumnarBatch] = None - - private[this] def writeBatchWithRetry(batch: ColumnarBatch): Long = { - val sb = SpillableColumnarBatch(batch, SpillPriorities.ACTIVE_ON_DECK_PRIORITY) - RmmRapidsRetryIterator.withRetry(sb, RmmRapidsRetryIterator.splitSpillableInHalfByRows) { sb => - val cr = new CheckpointRestore { - override def checkpoint(): Unit = () - override def restore(): Unit = dropBufferedData() + def writeSpillableAndClose( + spillableBatch: SpillableColumnarBatch, + statsTrackers: Seq[ColumnarWriteTaskStatsTracker]): Long = { + val writeStartTime = System.nanoTime + closeOnExcept(spillableBatch) { _ => + val cb = withRetryNoSplit[ColumnarBatch] { + spillableBatch.getColumnarBatch() } - val startTimestamp = System.nanoTime - withResource(sb.getColumnarBatch()) { cb => - //TODO: we should really apply the transformations to cast timestamps - // to the expected types before spilling but we need a SpillableTable - // rather than a SpillableColumnBatch to be able to do that - // See https://github.com/NVIDIA/spark-rapids/issues/8262 - RmmRapidsRetryIterator.withRestoreOnRetry(cr) { - withResource(new NvtxRange(s"GPU $rangeName write", NvtxColor.BLUE)) { _ => - scan(cb) - transform(cb) match { - case Some(transformed) => - // because we created a new transformed batch, we need to make sure we close it - withResource(transformed) { _ => - write(transformed) - } - case _ => - write(cb) - } - } - } + // run pre-flight checks and update stats + withResource(cb) { _ => + throwIfRebaseNeededInExceptionMode(cb) + // NOTE: it is imperative that `newBatch` is not in a retry block. 
+ // Otherwise it WILL corrupt writers that generate metadata in this method (like delta) + statsTrackers.foreach(_.newBatch(path(), cb)) } - GpuSemaphore.releaseIfNecessary(TaskContext.get) - val gpuTime = System.nanoTime - startTimestamp - writeBufferedData() - gpuTime - }.sum - } - - private[this] def writeBatchNoRetry(batch: ColumnarBatch): Long = { - var needToCloseBatch = true - try { - val startTimestamp = System.nanoTime - withResource(new NvtxRange(s"GPU $rangeName write", NvtxColor.BLUE)) { _ => - scan(batch) - transform(batch) match { - case Some(transformed) => - // because we created a new transformed batch, we need to make sure we close it - withResource(transformed) { _ => - write(transformed) - } - case _ => - write(batch) + } + val gpuTime = if (includeRetry) { + //TODO: we should really apply the transformations to cast timestamps + // to the expected types before spilling but we need a SpillableTable + // rather than a SpillableColumnBatch to be able to do that + // See https://github.com/NVIDIA/spark-rapids/issues/8262 + withRetry(spillableBatch, splitSpillableInHalfByRows) { attempt => + withRestoreOnRetry(checkpointRestore) { + bufferBatchAndClose(attempt.getColumnarBatch()) } + }.sum + } else { + withResource(spillableBatch) { _ => + bufferBatchAndClose(spillableBatch.getColumnarBatch()) } + } + // we successfully buffered to host memory, release the semaphore and write + // the buffered data to the FS + GpuSemaphore.releaseIfNecessary(TaskContext.get) + writeBufferedData() + updateStatistics(writeStartTime, gpuTime, statsTrackers) + spillableBatch.numRows() + } - // Batch is no longer needed, write process from here does not use GPU. - batch.close() - needToCloseBatch = false - GpuSemaphore.releaseIfNecessary(TaskContext.get) - val gpuTime = System.nanoTime - startTimestamp - writeBufferedData() - gpuTime - } finally { - if (needToCloseBatch) { - batch.close() + // protected for testing + protected[this] def bufferBatchAndClose(batch: ColumnarBatch): Long = { + val startTimestamp = System.nanoTime + withResource(new NvtxRange(s"GPU $rangeName write", NvtxColor.BLUE)) { _ => + withResource(transformAndClose(batch)) { maybeTransformed => + encodeAndBufferToHost(maybeTransformed) } } + // time spent on GPU encoding to the host sink + System.nanoTime - startTimestamp } - private def scan(batch: ColumnarBatch): Unit = { - withResource(GpuColumnVector.from(batch)) { table => - scanTableBeforeWrite(table) - } + /** Apply any necessary casts before writing batch out */ + def transformAndClose(cb: ColumnarBatch): ColumnarBatch = cb + + private val checkpointRestore = new CheckpointRestore { + override def checkpoint(): Unit = () + override def restore(): Unit = dropBufferedData() } - private def write(batch: ColumnarBatch): Unit = { + private def encodeAndBufferToHost(batch: ColumnarBatch): Unit = { withResource(GpuColumnVector.from(batch)) { table => + // `anythingWritten` is set here as an indication that there was data at all + // to write, even if the `tableWriter.write` method fails. If we fail to write + // and the task fails, any output is going to be discarded anyway, so no data + // corruption to worry about. Otherwise, we should retry (OOM case). + // If we have nothing to write, we won't flip this flag to true and we will + // buffer an empty batch on close() to work around issues in cuDF + // where corrupt files can be written if nothing is encoded via the writer. 
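+      // On a retry, writeSpillableAndClose first discards any partially buffered host data
+      // via dropBufferedData() (the checkpointRestore hook below) and then calls
+      // bufferBatchAndClose again, possibly with a smaller split of the spillable batch.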
anythingWritten = true tableWriter.write(table) } @@ -245,9 +212,10 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext, def close(): Unit = { if (!anythingWritten) { // This prevents writing out bad files - writeBatch(GpuColumnVector.emptyBatch(dataSchema)) + bufferBatchAndClose(GpuColumnVector.emptyBatch(dataSchema)) } tableWriter.close() + GpuSemaphore.releaseIfNecessary(TaskContext.get()) writeBufferedData() outputStream.close() } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala index a9a36ca0f5d..71d47e99196 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala @@ -290,6 +290,10 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging { override def getFileExtension(context: TaskAttemptContext): String = { CodecConfig.from(context).getCodec.getExtension + ".parquet" } + + override def partitionFlushSize(context: TaskAttemptContext): Long = + context.getConfiguration.getLong("write.parquet.row-group-size-bytes", + 128L * 1024L * 1024L) // 128M } } } @@ -306,9 +310,9 @@ class GpuParquetWriter( val outputTimestampType = conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key) - override def scanTableBeforeWrite(table: Table): Unit = { - (0 until table.getNumberOfColumns).foreach { i => - val col = table.getColumn(i) + override def throwIfRebaseNeededInExceptionMode(batch: ColumnarBatch): Unit = { + val cols = GpuColumnVector.extractBases(batch) + cols.foreach { col => // if col is a day if (dateRebaseException && RebaseHelper.isDateRebaseNeededInWrite(col)) { throw DataSourceUtils.newRebaseExceptionInWrite("Parquet") @@ -320,12 +324,14 @@ class GpuParquetWriter( } } - override def transform(batch: ColumnarBatch): Option[ColumnarBatch] = { - val transformedCols = GpuColumnVector.extractColumns(batch).safeMap { cv => - new GpuColumnVector(cv.dataType, deepTransformColumn(cv.getBase, cv.dataType)) - .asInstanceOf[org.apache.spark.sql.vectorized.ColumnVector] + override def transformAndClose(batch: ColumnarBatch): ColumnarBatch = { + withResource(batch) { _ => + val transformedCols = GpuColumnVector.extractColumns(batch).safeMap { cv => + new GpuColumnVector(cv.dataType, deepTransformColumn(cv.getBase, cv.dataType)) + .asInstanceOf[org.apache.spark.sql.vectorized.ColumnVector] + } + new ColumnarBatch(transformedCols) } - Some(new ColumnarBatch(transformedCols)) } private def deepTransformColumn(cv: ColumnVector, dt: DataType): ColumnVector = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala index aec7e7ca22e..724b71fe631 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala @@ -786,6 +786,13 @@ object RapidsBufferCatalog extends Logging { closeImpl() } + /** + * Only used in unit tests, it returns the number of buffers in the catalog. 
+ */ + def numBuffers: Int = { + _singleton.numBuffers + } + private def closeImpl(): Unit = synchronized { if (_singleton != null) { _singleton.close() diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveTextFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveTextFileFormat.scala index a2f47749d70..5deb0772f59 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveTextFileFormat.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/hive/rapids/GpuHiveTextFileFormat.scala @@ -137,40 +137,42 @@ class GpuHiveTextWriter(override val path: String, * This writer currently reformats timestamp and floating point * columns. */ - override def transform(cb: ColumnarBatch): Option[ColumnarBatch] = { - withResource(GpuColumnVector.from(cb)) { table => - val columns = for (i <- 0 until table.getNumberOfColumns) yield { - table.getColumn(i) match { - case c if c.getType.hasTimeResolution => - // By default, the CUDF CSV writer writes timestamps in the following format: - // "2020-09-16T22:32:01.123456Z" - // Hive's LazySimpleSerDe format expects timestamps to be formatted thus: - // "uuuu-MM-dd HH:mm:ss[.SSS...]" - // (Specifically, no `T` between `dd` and `HH`, and no `Z` at the end.) - val col = withResource(c.asStrings("%Y-%m-%d %H:%M:%S.%f")) { asStrings => - withResource(Scalar.fromString("\\N")) { nullString => - asStrings.replaceNulls(nullString) + override def transformAndClose(cb: ColumnarBatch): ColumnarBatch = { + withResource(cb) { _ => + withResource(GpuColumnVector.from(cb)) { table => + val columns = for (i <- 0 until table.getNumberOfColumns) yield { + table.getColumn(i) match { + case c if c.getType.hasTimeResolution => + // By default, the CUDF CSV writer writes timestamps in the following format: + // "2020-09-16T22:32:01.123456Z" + // Hive's LazySimpleSerDe format expects timestamps to be formatted thus: + // "uuuu-MM-dd HH:mm:ss[.SSS...]" + // (Specifically, no `T` between `dd` and `HH`, and no `Z` at the end.) + val col = withResource(c.asStrings("%Y-%m-%d %H:%M:%S.%f")) { asStrings => + withResource(Scalar.fromString("\\N")) { nullString => + asStrings.replaceNulls(nullString) + } } - } - GpuColumnVector.from(col, StringType) - case c if c.getType == DType.FLOAT32 || c.getType == DType.FLOAT64 => - // By default, the CUDF CSV writer writes floats with value `Infinity` - // as `"Inf"`. - // Hive's LazySimplSerDe expects such values to be written as `"Infinity"`. - // All occurrences of `Inf` need to be replaced with `Infinity`. - val col = withResource(c.castTo(DType.STRING)) { asStrings => - withResource(Scalar.fromString("Inf")) { infString => - withResource(Scalar.fromString("Infinity")) { infinityString => - asStrings.stringReplace(infString, infinityString) + GpuColumnVector.from(col, StringType) + case c if c.getType == DType.FLOAT32 || c.getType == DType.FLOAT64 => + // By default, the CUDF CSV writer writes floats with value `Infinity` + // as `"Inf"`. + // Hive's LazySimplSerDe expects such values to be written as `"Infinity"`. + // All occurrences of `Inf` need to be replaced with `Infinity`. 
+ val col = withResource(c.castTo(DType.STRING)) { asStrings => + withResource(Scalar.fromString("Inf")) { infString => + withResource(Scalar.fromString("Infinity")) { infinityString => + asStrings.stringReplace(infString, infinityString) + } } } - } - GpuColumnVector.from(col, StringType) - case c => - GpuColumnVector.from(c.incRefCount(), cb.column(i).dataType()) + GpuColumnVector.from(col, StringType) + case c => + GpuColumnVector.from(c.incRefCount(), cb.column(i).dataType()) + } } + new ColumnarBatch(columns.toArray, cb.numRows()) } - Some(new ColumnarBatch(columns.toArray, cb.numRows())) } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/BasicColumnarWriteStatsTracker.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/BasicColumnarWriteStatsTracker.scala index 4178fbefaf8..7d2f6f90036 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/BasicColumnarWriteStatsTracker.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/BasicColumnarWriteStatsTracker.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -154,7 +154,7 @@ class BasicColumnarWriteTaskStatsTracker( } override def newBatch(filePath: String, batch: ColumnarBatch): Unit = { - numRows += batch.numRows + numRows += batch.numRows() } override def getFinalStats(taskCommitTime: Long): WriteTaskStats = { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ColumnarWriteStatsTracker.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ColumnarWriteStatsTracker.scala index fc90f6afd0a..54269bb04dd 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ColumnarWriteStatsTracker.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/ColumnarWriteStatsTracker.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,6 +52,9 @@ trait ColumnarWriteTaskStatsTracker { /** * Process a new column batch to update the tracked statistics accordingly. * The batch will be written to the most recently witnessed file (via `newFile`). + * @note Call this function only once per `batch` to be written. If the batch is going to be + * split later because of a retry, that is OK, but don't call newBatch again with the + * splitted out parts. * @param filePath Path of the file which the batch is written to. * @param batch Current data batch to be processed. 
*/ diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala index 37697382eba..4193652405d 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala @@ -19,17 +19,17 @@ package org.apache.spark.sql.rapids import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, ListBuffer} -import ai.rapids.cudf.{ColumnVector, ContiguousTable, OrderByArg, Table} +import ai.rapids.cudf.{ColumnVector, OrderByArg, Table} import com.nvidia.spark.TimingUtils import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit import com.nvidia.spark.rapids.shims.GpuFileFormatDataWriterShim import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.TaskAttemptContext import org.apache.spark.TaskContext -import org.apache.spark.internal.Logging import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -70,7 +70,7 @@ object GpuFileFormatDataWriter { /** * Split a table into parts if recordsInFile + batch row count would go above - * maxRecordsPerFile. + * maxRecordsPerFile and make the splits spillable. * * The logic to find out what the splits should be is delegated to getSplitIndexes. * @@ -80,12 +80,12 @@ object GpuFileFormatDataWriter { * @param batch ColumnarBatch to split (and close) * @param maxRecordsPerFile max rowcount per file * @param recordsInFile row count in the file so far - * @return array of ColumnarBatch splits + * @return array of SpillableColumnarBatch splits */ def splitToFitMaxRecordsAndClose( batch: ColumnarBatch, maxRecordsPerFile: Long, - recordsInFile: Long): Array[ColumnarBatch] = { + recordsInFile: Long): Array[SpillableColumnarBatch] = { val (types, splitIndexes) = closeOnExcept(batch) { _ => val types = GpuColumnVector.extractTypes(batch) val splitIndexes = @@ -98,7 +98,7 @@ object GpuFileFormatDataWriter { if (splitIndexes.isEmpty) { // this should never happen, as `splitToFitMaxRecordsAndClose` is called when // splits should already happen, but making it more efficient in that case - Seq(batch).toArray + Array(SpillableColumnarBatch(batch, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)) } else { // actually split it val tbl = withResource(batch) { _ => @@ -108,7 +108,8 @@ object GpuFileFormatDataWriter { tbl.contiguousSplit(splitIndexes: _*) } withResource(cts) { _ => - cts.safeMap(ct => GpuColumnVector.from(ct.getTable, types)) + cts.safeMap(ct => + SpillableColumnarBatch(ct, types, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)) } } } @@ -147,8 +148,8 @@ abstract class GpuFileFormatDataWriter( } } - /** Release all resources. */ - protected def releaseResources(): Unit = { + /** Release all resources. Public for testing */ + def releaseResources(): Unit = { // Call `releaseCurrentWriter()` by default, as this is the only resource to be released. 
releaseCurrentWriter() } @@ -235,16 +236,19 @@ class GpuSingleDirectoryDataWriter( statsTrackers.foreach(_.newFile(currentPath)) } + private def writeUpdateMetricsAndClose(scb: SpillableColumnarBatch): Unit = { + recordsInFile += currentWriter.writeSpillableAndClose(scb, statsTrackers) + } + override def write(batch: ColumnarBatch): Unit = { val maxRecordsPerFile = description.maxRecordsPerFile - if (!shouldSplitToFitMaxRecordsPerFile(maxRecordsPerFile, recordsInFile, batch.numRows())) { - closeOnExcept(batch) { _ => - statsTrackers.foreach(_.newBatch(currentWriter.path(), batch)) - recordsInFile += batch.numRows() - } - currentWriter.writeAndClose(batch, statsTrackers) + if (!shouldSplitToFitMaxRecordsPerFile( + maxRecordsPerFile, recordsInFile, batch.numRows())) { + writeUpdateMetricsAndClose( + SpillableColumnarBatch(batch, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)) } else { - val partBatches = splitToFitMaxRecordsAndClose(batch, maxRecordsPerFile, recordsInFile) + val partBatches = splitToFitMaxRecordsAndClose( + batch, maxRecordsPerFile, recordsInFile) var needNewWriter = recordsInFile >= maxRecordsPerFile closeOnExcept(partBatches) { _ => partBatches.zipWithIndex.foreach { case (partBatch, partIx) => @@ -254,11 +258,9 @@ class GpuSingleDirectoryDataWriter( s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") newOutputWriter() } - statsTrackers.foreach(_.newBatch(currentWriter.path(), partBatch)) - recordsInFile += partBatch.numRows() // null out the entry so that we don't double close partBatches(partIx) = null - currentWriter.writeAndClose(partBatch, statsTrackers) + writeUpdateMetricsAndClose(partBatch) needNewWriter = true } } @@ -331,26 +333,22 @@ class GpuDynamicPartitionDataSingleWriter( """.stripMargin) /** Extracts the partition values out of an input batch. */ - protected lazy val getPartitionColumnsAsTable: ColumnarBatch => Table = { + protected lazy val getPartitionColumnsAsBatch: ColumnarBatch => ColumnarBatch = { val expressions = GpuBindReferences.bindGpuReferences( description.partitionColumns, description.allColumns) cb => { - withResource(GpuProjectExec.project(cb, expressions)) { batch => - GpuColumnVector.from(batch) - } + GpuProjectExec.project(cb, expressions) } } /** Extracts the output values of an input batch. 
*/ - private lazy val getOutputColumnsAsTable: ColumnarBatch => Table = { + private lazy val getOutputColumnsAsBatch: ColumnarBatch => ColumnarBatch= { val expressions = GpuBindReferences.bindGpuReferences( description.dataColumns, description.allColumns) cb => { - withResource(GpuProjectExec.project(cb, expressions)) { batch => - GpuColumnVector.from(batch) - } + GpuProjectExec.project(cb, expressions) } } @@ -407,7 +405,7 @@ class GpuDynamicPartitionDataSingleWriter( @scala.annotation.nowarn( "msg=method newTaskTempFile.* in class FileCommitProtocol is deprecated" ) - protected def newWriter( + def newWriter( partDir: String, bucketId: Option[Int], // Currently it's always None fileCounter: Int @@ -462,15 +460,16 @@ class GpuDynamicPartitionDataSingleWriter( } } - override def write(cb: ColumnarBatch): Unit = { + override def write(batch: ColumnarBatch): Unit = { // this single writer always passes `cachesMap` as None - write(cb, cachesMap = None) + write(batch, cachesMap = None) } - private case class SplitAndPath(split: ContiguousTable, path: String, partIx: Int) + private case class SplitAndPath(var split: SpillableColumnarBatch, path: String) extends AutoCloseable { override def close(): Unit = { - split.close() + split.safeClose() + split = null } } @@ -479,17 +478,20 @@ class GpuDynamicPartitionDataSingleWriter( * array of the splits as `ContiguousTable`'s, and an array of paths to use to * write each partition. */ - private def splitBatchByKey( + private def splitBatchByKeyAndClose( batch: ColumnarBatch, partDataTypes: Array[DataType]): Array[SplitAndPath] = { - val (outputColumnsTbl, partitionColumnsTbl) = withResource(batch) { _ => - closeOnExcept(getOutputColumnsAsTable(batch)) { outputColumnsTbl => - closeOnExcept(getPartitionColumnsAsTable(batch)) { partitionColumnsTbl => - (outputColumnsTbl, partitionColumnsTbl) + val (outputColumnsBatch, partitionColumnsBatch) = withResource(batch) { _ => + closeOnExcept(getOutputColumnsAsBatch(batch)) { outputColumnsBatch => + closeOnExcept(getPartitionColumnsAsBatch(batch)) { partitionColumnsBatch => + (outputColumnsBatch, partitionColumnsBatch) } } } - val (cbKeys, partitionIndexes) = closeOnExcept(outputColumnsTbl) { _ => + val (cbKeys, partitionIndexes) = closeOnExcept(outputColumnsBatch) { _ => + val partitionColumnsTbl = withResource(partitionColumnsBatch) { _ => + GpuColumnVector.from(partitionColumnsBatch) + } withResource(partitionColumnsTbl) { _ => withResource(distinctAndSort(partitionColumnsTbl)) { distinctKeysTbl => val partitionIndexes = splitIndexes(partitionColumnsTbl, distinctKeysTbl) @@ -498,11 +500,21 @@ class GpuDynamicPartitionDataSingleWriter( } } } + val splits = closeOnExcept(cbKeys) { _ => - withResource(outputColumnsTbl) { _ => - outputColumnsTbl.contiguousSplit(partitionIndexes: _*) + val spillableOutputColumnsBatch = + SpillableColumnarBatch(outputColumnsBatch, SpillPriorities.ACTIVE_ON_DECK_PRIORITY) + withRetryNoSplit(spillableOutputColumnsBatch) { spillable => + withResource(spillable.getColumnarBatch()) { outCb => + withResource(GpuColumnVector.from(outCb)) { outputColumnsTbl => + withResource(outputColumnsTbl) { _ => + outputColumnsTbl.contiguousSplit(partitionIndexes: _*) + } + } + } } } + val paths = closeOnExcept(splits) { _ => withResource(cbKeys) { _ => // Use the existing code to convert each row into a path. 
It would be nice to do this @@ -517,10 +529,53 @@ class GpuDynamicPartitionDataSingleWriter( // NOTE: the `zip` here has the effect that will remove an extra `ContiguousTable` // added at the end of `splits` because we use `upperBound` to find the split points, // and the last split point is the number of rows. + val outDataTypes = description.dataColumns.map(_.dataType).toArray splits.zip(paths).zipWithIndex.map { case ((split, path), ix) => splits(ix) = null - SplitAndPath(split, path, ix) + withResource(split) { _ => + SplitAndPath( + SpillableColumnarBatch( + split, outDataTypes, SpillPriorities.ACTIVE_BATCHING_PRIORITY), + path) + } + } + } + } + + private def getBatchToWrite( + partBatch: SpillableColumnarBatch, + savedStatus: Option[WriterStatusWithCaches]): SpillableColumnarBatch = { + val outDataTypes = description.dataColumns.map(_.dataType).toArray + if (savedStatus.isDefined && savedStatus.get.tableCaches.nonEmpty) { + // In the case where the concurrent partition writers fall back, we need to + // incorporate into the current part any pieces that are already cached + // in the `savedStatus`. Adding `partBatch` to what was saved could make a + // concatenated batch with number of rows larger than `maxRecordsPerFile`, + // so this concatenated result could be split later, which is not efficient. However, + // the concurrent writers are default off in Spark, so it is not clear if this + // code path is worth optimizing. + val concat = + withResource(savedStatus.get.tableCaches) { subSpillableBatches => + val toConcat = subSpillableBatches :+ partBatch + + // clear the caches + savedStatus.get.tableCaches.clear() + + withRetryNoSplit(toConcat) { spillables => + withResource(spillables.safeMap(_.getColumnarBatch())) { batches => + withResource(batches.map(GpuColumnVector.from)) { subTables => + Table.concatenate(subTables: _*) + } + } + } + } + withResource(concat) { _ => + SpillableColumnarBatch( + GpuColumnVector.from(concat, outDataTypes), + SpillPriorities.ACTIVE_ON_DECK_PRIORITY) } + } else { + partBatch } } @@ -535,54 +590,39 @@ class GpuDynamicPartitionDataSingleWriter( * writer, single writer should handle the stored writers and the pending caches */ protected def write( - cb: ColumnarBatch, + batch: ColumnarBatch, cachesMap: Option[mutable.HashMap[String, WriterStatusWithCaches]]): Unit = { assert(isPartitioned) assert(!isBucketed) val maxRecordsPerFile = description.maxRecordsPerFile val partDataTypes = description.partitionColumns.map(_.dataType).toArray - val outDataTypes = description.dataColumns.map(_.dataType).toArray // We have an entire batch that is sorted, so we need to split it up by key // to get a batch per path - val splitsAndPaths = splitBatchByKey(cb, partDataTypes) - withResource(splitsAndPaths) { _ => - splitsAndPaths.foreach { case SplitAndPath(partContigTable, partPath, partIx) => - // If fall back from for `GpuDynamicPartitionDataConcurrentWriter`, we should get the + withResource(splitBatchByKeyAndClose(batch, partDataTypes)) { splitsAndPaths => + splitsAndPaths.zipWithIndex.foreach { case (SplitAndPath(partBatch, partPath), ix) => + // If we fall back from `GpuDynamicPartitionDataConcurrentWriter`, we should get the // saved status val savedStatus = updateCurrentWriterIfNeeded(partPath, cachesMap) - val batchToWrite = - if (savedStatus.isDefined && savedStatus.get.tableCaches.nonEmpty) { - // convert caches seq to tables and close caches seq - val subTables = convertSpillBatchesToTablesAndClose(savedStatus.get.tableCaches) - // concat the caches and 
this `table` - val concat = withResource(subTables) { _ => - // clear the caches - savedStatus.get.tableCaches.clear() - splitsAndPaths(partIx) = null - withResource(partContigTable) { _ => - subTables += partContigTable.getTable - Table.concatenate(subTables: _*) - } - } - withResource(concat) { _ => - GpuColumnVector.from(concat, outDataTypes) - } - } else { - splitsAndPaths(partIx) = null - withResource(partContigTable) { _ => - GpuColumnVector.from(partContigTable.getTable, outDataTypes) - } - } + + // combine `partBatch` with any remnants for this partition for the concurrent + // writer fallback case in `savedStatus` + splitsAndPaths(ix) = null + val batchToWrite = getBatchToWrite(partBatch, savedStatus) // if the batch fits, write it as is, else split and write it. if (!shouldSplitToFitMaxRecordsPerFile(maxRecordsPerFile, currentWriterStatus.recordsInFile, batchToWrite.numRows())) { - writeBatchUsingCurrentWriterAndClose(batchToWrite) + writeUpdateMetricsAndClose(currentWriterStatus, batchToWrite) } else { + // materialize an actual batch since we are going to split it + // on the GPU + val batchToSplit = withRetryNoSplit(batchToWrite) { _ => + batchToWrite.getColumnarBatch() + } val maxRecordsPerFileSplits = splitToFitMaxRecordsAndClose( - batchToWrite, + batchToSplit, maxRecordsPerFile, currentWriterStatus.recordsInFile) writeSplitBatchesAndClose(maxRecordsPerFileSplits, maxRecordsPerFile, partPath) @@ -632,21 +672,21 @@ class GpuDynamicPartitionDataSingleWriter( } /** - * Write an array of batches. + * Write an array of spillable batches. * - * Note: `batches` will be closed in this function. + * Note: `spillableBatches` will be closed in this function. * - * @param batches the ColumnarBatch splits to be written + * @param batches the SpillableColumnarBatch splits to be written * @param maxRecordsPerFile the max number of rows per file * @param partPath the partition directory */ private def writeSplitBatchesAndClose( - batches: Array[ColumnarBatch], + spillableBatches: Array[SpillableColumnarBatch], maxRecordsPerFile: Long, partPath: String): Unit = { var needNewWriter = currentWriterStatus.recordsInFile >= maxRecordsPerFile - withResource(batches) { _ => - batches.zipWithIndex.foreach { case (part, partIx) => + withResource(spillableBatches) { _ => + spillableBatches.zipWithIndex.foreach { case (part, partIx) => if (needNewWriter) { currentWriterStatus.fileCounter += 1 assert(currentWriterStatus.fileCounter <= MAX_FILE_COUNTER, @@ -663,19 +703,18 @@ class GpuDynamicPartitionDataSingleWriter( newWriter(partPath, None, currentWriterStatus.fileCounter) currentWriterStatus.recordsInFile = 0 } - batches(partIx) = null - writeBatchUsingCurrentWriterAndClose(part) + spillableBatches(partIx) = null + writeUpdateMetricsAndClose(currentWriterStatus, part) needNewWriter = true } } } - private def writeBatchUsingCurrentWriterAndClose(batch: ColumnarBatch): Unit = { - closeOnExcept(batch) { _ => - statsTrackers.foreach(_.newBatch(currentWriterStatus.outputWriter.path(), batch)) - currentWriterStatus.recordsInFile += batch.numRows() - } - currentWriterStatus.outputWriter.writeAndClose(batch, statsTrackers) + protected def writeUpdateMetricsAndClose( + writerStatus: WriterStatus, + spillableBatch: SpillableColumnarBatch): Unit = { + writerStatus.recordsInFile += + writerStatus.outputWriter.writeSpillableAndClose(spillableBatch, statsTrackers) } /** Release all resources. 
*/ @@ -692,26 +731,6 @@ class GpuDynamicPartitionDataSingleWriter( } } } - - /** - * convert spillable columnar batch seq to tables and close the input `spills` - * - * @param spills spillable columnar batch seq - * @return table array - */ - def convertSpillBatchesToTablesAndClose( - spills: Seq[SpillableColumnarBatch]): ArrayBuffer[Table] = { - withResource(spills) { _ => - val subTablesBuffer = new ArrayBuffer[Table] - spills.foreach { spillableCb => - withResource(spillableCb.getColumnarBatch()) { cb => - val currTable = GpuColumnVector.from(cb) - subTablesBuffer += currTable - } - } - subTablesBuffer - } - } } /** @@ -734,39 +753,31 @@ class GpuDynamicPartitionDataConcurrentWriter( description: GpuWriteJobDescription, taskAttemptContext: TaskAttemptContext, committer: FileCommitProtocol, - spec: GpuConcurrentOutputWriterSpec) - extends GpuDynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) - with Logging { + spec: GpuConcurrentOutputWriterSpec, + taskContext: TaskContext) + extends GpuDynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) { // Keep all the unclosed writers, key is partition directory string. // Note: if fall back to sort-based mode, also use the opened writers in the map. private val concurrentWriters = mutable.HashMap[String, WriterStatusWithCaches]() // guarantee to close the caches and writers when task is finished - TaskContext.get().addTaskCompletionListener[Unit](_ => closeCachesAndWriters()) + taskContext.addTaskCompletionListener[Unit](_ => closeCachesAndWriters()) private val outDataTypes = description.dataColumns.map(_.dataType).toArray - val partitionFlushSize = if (description.concurrentWriterPartitionFlushSize <= 0) { - // if the property is equal or less than 0, use default value of parquet or orc - val extension = description.outputWriterFactory - .getFileExtension(taskAttemptContext).toLowerCase() - if (extension.endsWith("parquet")) { - taskAttemptContext.getConfiguration.getLong("write.parquet.row-group-size-bytes", - 128L * 1024L * 1024L) // 128M - } else if (extension.endsWith("orc")) { - taskAttemptContext.getConfiguration.getLong("orc.stripe.size", - 64L * 1024L * 1024L) // 64M + private val partitionFlushSize = + if (description.concurrentWriterPartitionFlushSize <= 0) { + // if the property is equal or less than 0, use default value given by the + // writer factory + description.outputWriterFactory.partitionFlushSize(taskAttemptContext) } else { - 128L * 1024L * 1024L // 128M + // if the property is greater than 0, use the property value + description.concurrentWriterPartitionFlushSize } - } else { - // if the property is greater than 0, use the property value - description.concurrentWriterPartitionFlushSize - } // refer to current batch if should fall back to `single writer` - var currentFallbackColumnarBatch: ColumnarBatch = _ + private var currentFallbackColumnarBatch: ColumnarBatch = _ override def abort(): Unit = { try { @@ -782,14 +793,14 @@ class GpuDynamicPartitionDataConcurrentWriter( */ private var fallBackToSortBased: Boolean = false - def writeWithSingleWriter(cb: ColumnarBatch): Unit = { + private def writeWithSingleWriter(cb: ColumnarBatch): Unit = { // invoke `GpuDynamicPartitionDataSingleWriter`.write, // single writer will take care of the unclosed writers and the pending caches // in `concurrentWriters` super.write(cb, Some(concurrentWriters)) } - def writeWithConcurrentWriter(cb: ColumnarBatch): Unit = { + private def writeWithConcurrentWriter(cb: ColumnarBatch): Unit = { 
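+    // Concurrent path: dispatch through this class's write(). The sort-based fallback
+    // uses writeWithSingleWriter above, which hands the open writers and pending caches
+    // in `concurrentWriters` over to the single-writer logic via super.write.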
this.write(cb) } @@ -812,6 +823,7 @@ class GpuDynamicPartitionDataConcurrentWriter( // concat the put back batch and un-coming batches val newIterator = Iterator.single(currentFallbackColumnarBatch) ++ iterator // sort the all the batches in `iterator` + val sortIterator: GpuOutOfCoreSortIterator = getSorted(newIterator) while (sortIterator.hasNext) { // write with sort-based single writer @@ -864,12 +876,13 @@ class GpuDynamicPartitionDataConcurrentWriter( // 1. combine partition columns and `cb` columns into a column array val columnsWithPartition = ArrayBuffer[ColumnVector]() - withResource(getPartitionColumnsAsTable(cb)) { partitionColumnsTable => - for (i <- 0 until partitionColumnsTable.getNumberOfColumns) { - // append partition column - columnsWithPartition += partitionColumnsTable.getColumn(i) - } + + // this withResource is here to decrement the refcount of the partition columns + // that are projected out of `cb` + withResource(getPartitionColumnsAsBatch(cb)) { partitionColumnsBatch => + columnsWithPartition.appendAll(GpuColumnVector.extractBases(partitionColumnsBatch)) } + val cols = GpuColumnVector.extractBases(cb) columnsWithPartition ++= cols @@ -980,35 +993,38 @@ class GpuDynamicPartitionDataConcurrentWriter( private def writeAndCloseCache(partitionDir: String, status: WriterStatusWithCaches): Unit = { assert(status.tableCaches.nonEmpty) - // convert spillable caches to tables, and close `status.tableCaches` - val subTables = convertSpillBatchesToTablesAndClose(status.tableCaches) - // get concat table or the single table - val t = if (status.tableCaches.length >= 2) { + val spillableToWrite = if (status.tableCaches.length >= 2) { // concat the sub batches to write in once. - withResource(subTables) { _ => - Table.concatenate(subTables: _*) + val concatted = withRetryNoSplit(status.tableCaches) { spillableSubBatches => + withResource(spillableSubBatches.safeMap(_.getColumnarBatch())) { subBatches => + withResource(subBatches.map(GpuColumnVector.from)) { subTables => + Table.concatenate(subTables: _*) + } + } + } + withResource(concatted) { _ => + SpillableColumnarBatch( + GpuColumnVector.from(concatted, outDataTypes), + SpillPriorities.ACTIVE_ON_DECK_PRIORITY) } } else { // only one single table - subTables.head + status.tableCaches.head } - val maxRecordsPerFile = description.maxRecordsPerFile - val batch = withResource(t) { _ => - GpuColumnVector.from(t, outDataTypes) - } + status.tableCaches.clear() + val maxRecordsPerFile = description.maxRecordsPerFile if (!shouldSplitToFitMaxRecordsPerFile( - maxRecordsPerFile, status.writerStatus.recordsInFile, batch.numRows())) { - closeOnExcept(batch) { _ => - statsTrackers.foreach(_.newBatch(status.writerStatus.outputWriter.path(), batch)) - status.writerStatus.recordsInFile += batch.numRows() - } - status.writerStatus.outputWriter.writeAndClose(batch, statsTrackers) + maxRecordsPerFile, status.writerStatus.recordsInFile, spillableToWrite.numRows())) { + writeUpdateMetricsAndClose(status.writerStatus, spillableToWrite) } else { + val batchToSplit = withRetryNoSplit(spillableToWrite) { _ => + spillableToWrite.getColumnarBatch() + } val splits = splitToFitMaxRecordsAndClose( - batch, + batchToSplit, maxRecordsPerFile, status.writerStatus.recordsInFile) var needNewWriter = status.writerStatus.recordsInFile >= maxRecordsPerFile @@ -1018,19 +1034,15 @@ class GpuDynamicPartitionDataConcurrentWriter( status.writerStatus.fileCounter += 1 assert(status.writerStatus.fileCounter <= MAX_FILE_COUNTER, s"File counter 
${status.writerStatus.fileCounter} " + - s"is beyond max value $MAX_FILE_COUNTER") + s"is beyond max value $MAX_FILE_COUNTER") status.writerStatus.outputWriter.close() - // start a new writer val w = newWriter(partitionDir, None, status.writerStatus.fileCounter) status.writerStatus.outputWriter = w status.writerStatus.recordsInFile = 0L } - // close the contiguous table - statsTrackers.foreach(_.newBatch(status.writerStatus.outputWriter.path(), split)) - status.writerStatus.recordsInFile += split.numRows() splits(partIndex) = null - status.writerStatus.outputWriter.writeAndClose(split, statsTrackers) + writeUpdateMetricsAndClose(status.writerStatus, split) needNewWriter = true } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala index fddefc9eb88..354245e2a75 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuOrcFileFormat.scala @@ -171,6 +171,10 @@ class GpuOrcFileFormat extends ColumnarFileFormat with Logging { compressionExtension + ".orc" } + + override def partitionFlushSize(context: TaskAttemptContext): Long = { + context.getConfiguration.getLong("orc.stripe.size", 64L * 1024L * 1024L) // 64M + } } } } diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala index a32707ce73a..4ea1cbb555e 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala @@ -328,7 +328,7 @@ object GpuFileFormatWriter extends Logging { concurrentOutputWriterSpec match { case Some(spec) => new GpuDynamicPartitionDataConcurrentWriter( - description, taskAttemptContext, committer, spec) + description, taskAttemptContext, committer, spec, TaskContext.get()) case _ => new GpuDynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) } diff --git a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala index 672b7d24be4..71abd5f1cf6 100644 --- a/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala +++ b/sql-plugin/src/main/spark311/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala @@ -39,4 +39,9 @@ object SparkUpgradeExceptionShims { new SparkUpgradeException(version, message, cause) } + // Used in tests to compare the class seen in an exception to + // `SparkUpgradeException` which is private in Spark + def getSparkUpgradeExceptionClass: Class[_] = { + classOf[SparkUpgradeException] + } } diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala index fa2cfbd71bf..ed220f3c56d 100644 --- a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala +++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala @@ -36,4 +36,10 @@ object SparkUpgradeExceptionShims { Array(version, message), cause) } + + // Used in tests to compare the class seen in an 
exception to + // `SparkUpgradeException` which is private in Spark + def getSparkUpgradeExceptionClass: Class[_] = { + classOf[SparkUpgradeException] + } } diff --git a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala index 9f67c31fc3f..2ae55c73057 100644 --- a/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala +++ b/sql-plugin/src/main/spark330db/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala @@ -33,4 +33,10 @@ object SparkUpgradeExceptionShims { Array(version, message), cause) } + + // Used in tests to compare the class seen in an exception to + // `SparkUpgradeException` which is private in Spark + def getSparkUpgradeExceptionClass: Class[_] = { + classOf[SparkUpgradeException] + } } diff --git a/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala b/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala index 5ae4231a6c0..4cbf674d4e4 100644 --- a/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala +++ b/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala @@ -413,7 +413,7 @@ object GpuFileFormatWriter extends Logging { concurrentOutputWriterSpec match { case Some(spec) => new GpuDynamicPartitionDataConcurrentWriter( - description, taskAttemptContext, committer, spec) + description, taskAttemptContext, committer, spec, TaskContext.get()) case _ => new GpuDynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) } diff --git a/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala b/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala index 322073d86fa..a863cd9109e 100644 --- a/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala +++ b/sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/shims/SparkUpgradeExceptionShims.scala @@ -34,4 +34,10 @@ object SparkUpgradeExceptionShims { Map(version -> message), cause) } + + // Used in tests to compare the class seen in an exception to + // `SparkUpgradeException` which is private in Spark + def getSparkUpgradeExceptionClass: Class[_] = { + classOf[SparkUpgradeException] + } } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala index 6317ae300e9..d6aa870aa5b 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/CastOpSuite.scala @@ -25,22 +25,16 @@ import java.util.TimeZone import scala.collection.JavaConverters._ import scala.util.{Failure, Random, Success, Try} -import org.scalatest.BeforeAndAfterAll - import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, NamedExpression} import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.rapids.execution.TrampolineUtil import org.apache.spark.sql.types._ -class CastOpSuite extends GpuExpressionTestSuite with BeforeAndAfterAll { +class CastOpSuite extends GpuExpressionTestSuite { import CastOpSuite._ - override def 
afterAll(): Unit = { - TrampolineUtil.cleanupAnyExistingSession() - } private val sparkConf = new SparkConf() .set(RapidsConf.ENABLE_CAST_FLOAT_TO_INTEGRAL_TYPES.key, "true") diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala index 986601b239c..35d7f990798 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/DeviceMemoryEventHandlerSuite.scala @@ -18,16 +18,15 @@ package com.nvidia.spark.rapids import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.when -import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.mockito.MockitoSugar -class DeviceMemoryEventHandlerSuite extends AnyFunSuite with MockitoSugar { +class DeviceMemoryEventHandlerSuite extends RmmSparkRetrySuiteBase with MockitoSugar { test("a failed allocation should be retried if we spilled enough") { val mockCatalog = mock[RapidsBufferCatalog] val mockStore = mock[RapidsDeviceMemoryStore] - when(mockStore.currentSize).thenReturn(1024) - when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(1024)) + when(mockStore.currentSpillableSize).thenReturn(1024) + when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(1024L)) val handler = new DeviceMemoryEventHandler( mockCatalog, mockStore, @@ -40,8 +39,8 @@ class DeviceMemoryEventHandlerSuite extends AnyFunSuite with MockitoSugar { test("when we deplete the store, retry up to max failed OOM retries") { val mockCatalog = mock[RapidsBufferCatalog] val mockStore = mock[RapidsDeviceMemoryStore] - when(mockStore.currentSize).thenReturn(0) - when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(0)) + when(mockStore.currentSpillableSize).thenReturn(0) + when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(0L)) val handler = new DeviceMemoryEventHandler( mockCatalog, mockStore, @@ -56,8 +55,8 @@ class DeviceMemoryEventHandlerSuite extends AnyFunSuite with MockitoSugar { test("we reset our OOM state after a successful retry") { val mockCatalog = mock[RapidsBufferCatalog] val mockStore = mock[RapidsDeviceMemoryStore] - when(mockStore.currentSize).thenReturn(0) - when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(0)) + when(mockStore.currentSpillableSize).thenReturn(0) + when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(0L)) val handler = new DeviceMemoryEventHandler( mockCatalog, mockStore, @@ -75,8 +74,8 @@ class DeviceMemoryEventHandlerSuite extends AnyFunSuite with MockitoSugar { test("a negative allocation cannot be retried and handler throws") { val mockCatalog = mock[RapidsBufferCatalog] val mockStore = mock[RapidsDeviceMemoryStore] - when(mockStore.currentSize).thenReturn(1024) - when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(1024)) + when(mockStore.currentSpillableSize).thenReturn(1024) + when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(1024L)) val handler = new DeviceMemoryEventHandler( mockCatalog, mockStore, @@ -89,8 +88,8 @@ class DeviceMemoryEventHandlerSuite extends AnyFunSuite with MockitoSugar { test("a negative retry count is invalid") { val mockCatalog = mock[RapidsBufferCatalog] val mockStore = mock[RapidsDeviceMemoryStore] - when(mockStore.currentSize).thenReturn(1024) - when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => 
Some(1024)) + when(mockStore.currentSpillableSize).thenReturn(1024) + when(mockCatalog.synchronousSpill(any(), any(), any())).thenAnswer(_ => Some(1024L)) val handler = new DeviceMemoryEventHandler( mockCatalog, mockStore, diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala index 93cc79f6e06..b1b9de2b4e3 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala @@ -25,10 +25,11 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.parquet.hadoop.ParquetFileReader -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.SparkConf import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol import org.apache.spark.sql.rapids.BasicColumnarWriteJobStatsTracker +import org.apache.spark.sql.rapids.shims.SparkUpgradeExceptionShims /** * Tests for writing Parquet files with the GPU. @@ -116,7 +117,9 @@ class ParquetWriterSuite extends SparkQueryCompareTestSuite { } test("set max records per file no partition") { - val conf = new SparkConf().set("spark.sql.files.maxRecordsPerFile", "50") + val conf = new SparkConf() + .set("spark.sql.files.maxRecordsPerFile", "50") + .set(RapidsConf.SQL_ENABLED.key, "true") val tempFile = File.createTempFile("maxRecords", ".parquet") val assertRowCount50 = assertResult(50) _ @@ -139,7 +142,10 @@ class ParquetWriterSuite extends SparkQueryCompareTestSuite { } test("set max records per file with partition") { - val conf = new SparkConf().set("spark.sql.files.maxRecordsPerFile", "50") + val conf = new SparkConf() + .set("spark.rapids.sql.batchSizeBytes", "1") // forces multiple batches per partition + .set("spark.sql.files.maxRecordsPerFile", "50") + .set(RapidsConf.SQL_ENABLED.key, "true") val tempFile = File.createTempFile("maxRecords", ".parquet") val assertRowCount50 = assertResult(50) _ @@ -166,8 +172,10 @@ class ParquetWriterSuite extends SparkQueryCompareTestSuite { Seq(("40", 40), ("200", 80)).foreach{ case (maxRecordsPerFile, expectedRecordsPerFile) => val conf = new SparkConf() + .set("spark.rapids.sql.batchSizeBytes", "1") // forces multiple batches per partition .set("spark.sql.files.maxRecordsPerFile", maxRecordsPerFile) .set("spark.sql.maxConcurrentOutputFileWriters", "30") + .set(RapidsConf.SQL_ENABLED.key, "true") try { SparkSessionHolder.withSparkSession(conf, spark => { import spark.implicits._ @@ -194,9 +202,44 @@ class ParquetWriterSuite extends SparkQueryCompareTestSuite { } } + test("set maxRecordsPerFile with partition concurrently fallback") { + val tempFile = File.createTempFile("maxRecords", ".parquet") + + Seq(("40", 40), ("200", 80)).foreach { case (maxRecordsPerFile, expectedRecordsPerFile) => + val conf = new SparkConf() + .set("spark.rapids.sql.batchSizeBytes", "1") // forces multiple batches per partition + .set("spark.sql.files.maxRecordsPerFile", maxRecordsPerFile) + .set("spark.sql.maxConcurrentOutputFileWriters", "10") + .set(RapidsConf.SQL_ENABLED.key, "true") + try { + SparkSessionHolder.withSparkSession(conf, spark => { + import spark.implicits._ + val df = (1 to 1600).map(i => (i, i % 20)).toDF() + df + .repartition(1) + .write + .mode("overwrite") + .partitionBy("_2") + .parquet(tempFile.getAbsolutePath()) + // check the whole number of rows + 
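+        // The input is 1600 rows over 20 partition values, i.e. 80 rows per partition:
+        // with maxRecordsPerFile=40 each partition splits into two 40-row files, and
+        // with 200 the 80 rows fit in one file, matching expectedRecordsPerFile above.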
assertResult(1600)(spark.read.parquet(tempFile.getAbsolutePath()).count()) + // check number of rows in each file + listAllFiles(tempFile) + .map(f => f.getAbsolutePath()) + .filter(p => p.endsWith("parquet")) + .map(p => { + assertResult(expectedRecordsPerFile)(spark.read.parquet(p).count()) + }) + }) + } finally { + fullyDelete(tempFile) + } + } + } + testExpectedGpuException( "Old dates in EXCEPTION mode", - classOf[SparkException], + SparkUpgradeExceptionShims.getSparkUpgradeExceptionClass, oldDatesDf, new SparkConf().set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "EXCEPTION")) { val tempFile = File.createTempFile("oldDates", "parquet") @@ -207,9 +250,10 @@ class ParquetWriterSuite extends SparkQueryCompareTestSuite { } } + testExpectedGpuException( "Old timestamps millis in EXCEPTION mode", - classOf[SparkException], + SparkUpgradeExceptionShims.getSparkUpgradeExceptionClass, oldTsDf, new SparkConf() .set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "EXCEPTION") @@ -224,7 +268,7 @@ class ParquetWriterSuite extends SparkQueryCompareTestSuite { testExpectedGpuException( "Old timestamps in EXCEPTION mode", - classOf[SparkException], + SparkUpgradeExceptionShims.getSparkUpgradeExceptionClass, oldTsDf, new SparkConf() .set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "EXCEPTION") diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala index 875531fa446..653f9acda7e 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala @@ -27,6 +27,7 @@ class RmmSparkRetrySuiteBase extends AnyFunSuite with BeforeAndAfterEach { private var rmmWasInitialized = false override def beforeEach(): Unit = { + super.beforeEach() SparkSession.getActiveSession.foreach(_.stop()) SparkSession.clearActiveSession() if (!Rmm.isInitialized) { @@ -43,9 +44,13 @@ class RmmSparkRetrySuiteBase extends AnyFunSuite with BeforeAndAfterEach { } override def afterEach(): Unit = { + super.afterEach() + SparkSession.getActiveSession.foreach(_.stop()) + SparkSession.clearActiveSession() RmmSpark.removeThreadAssociation(RmmSpark.getCurrentThreadId) RmmSpark.clearEventHandler() RapidsBufferCatalog.close() + GpuSemaphore.shutdown() if (rmmWasInitialized) { Rmm.shutdown() } diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala index 5c4138611ab..d14a3faccce 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala @@ -20,7 +20,7 @@ import java.nio.file.Files import java.sql.{Date, Timestamp} import java.util.{Locale, TimeZone} -import org.scalatest.Assertion +import org.scalatest.{Assertion, BeforeAndAfterAll} import org.scalatest.funsuite.AnyFunSuite import scala.reflect.ClassTag import scala.util.{Failure, Try} @@ -146,11 +146,16 @@ object SparkSessionHolder extends Logging { /** * Set of tests that compare the output using the CPU version of spark vs our GPU version. 
*/ -trait SparkQueryCompareTestSuite extends AnyFunSuite { +trait SparkQueryCompareTestSuite extends AnyFunSuite with BeforeAndAfterAll { import SparkSessionHolder.withSparkSession def enableCsvConf(): SparkConf = enableCsvConf(new SparkConf()) + override def afterAll(): Unit = { + super.afterAll() + TrampolineUtil.cleanupAnyExistingSession() + } + def enableCsvConf(conf: SparkConf): SparkConf = { conf .set(RapidsConf.ENABLE_READ_CSV_FLOATS.key, "true") @@ -906,9 +911,9 @@ trait SparkQueryCompareTestSuite extends AnyFunSuite { compareResults(sort, maxFloatDiff, fromCpu, fromGpu) } - def testExpectedGpuException[T <: Throwable]( + def testExpectedGpuException( testName: String, - exceptionClass: Class[T], + exceptionClass: Class[_], df: SparkSession => DataFrame, conf: SparkConf = new SparkConf(), repart: Integer = 1, @@ -934,13 +939,24 @@ trait SparkQueryCompareTestSuite extends AnyFunSuite { }, testConf) }) t match { - case Failure(e) if e.getClass == exceptionClass => // Good - case Failure(e) => throw e + case Failure(e) => assertResult(exceptionClass)(getRootCause(e).getClass) case _ => fail("Expected an exception") } } } + private def getRootCause(t: Throwable): Throwable = { + if (t == null) { + t + } else { + var current = t + while (current.getCause != null) { + current = current.getCause + } + current + } + } + def testExpectedException[T <: Throwable]( testName: String, expectedException: T => Boolean, diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala index 1d2379c1ff5..54bb8a7a3cd 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/RapidsShuffleTestHelper.scala @@ -22,14 +22,13 @@ import java.util.concurrent.Executor import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{ColumnVector, ContiguousTable, DeviceMemoryBuffer, HostMemoryBuffer} -import com.nvidia.spark.rapids.{GpuColumnVector, MetaUtils, RapidsBufferHandle, RapidsConf, RapidsDeviceMemoryStore, ShuffleMetadata, ShuffleReceivedBufferCatalog} +import com.nvidia.spark.rapids.{GpuColumnVector, MetaUtils, RapidsBufferHandle, RapidsConf, RapidsDeviceMemoryStore, RmmSparkRetrySuiteBase, ShuffleMetadata, ShuffleReceivedBufferCatalog} import com.nvidia.spark.rapids.Arm.withResource import com.nvidia.spark.rapids.format.TableMeta import org.mockito.ArgumentCaptor import org.mockito.ArgumentMatchers.any import org.mockito.Mockito.{spy, when} import org.scalatest.BeforeAndAfterEach -import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.sql.rapids.ShuffleMetricsUpdater @@ -52,9 +51,10 @@ class TestShuffleMetricsUpdater extends ShuffleMetricsUpdater { } } -class RapidsShuffleTestHelper extends AnyFunSuite - with BeforeAndAfterEach - with MockitoSugar { +abstract class RapidsShuffleTestHelper + extends RmmSparkRetrySuiteBase + with BeforeAndAfterEach + with MockitoSugar { var mockTransaction: Transaction = _ var mockConnection: MockClientConnection = _ var mockTransport: RapidsShuffleTransport = _ @@ -120,11 +120,13 @@ class RapidsShuffleTestHelper extends AnyFunSuite override def beforeEach(): Unit = { assert(buffersToClose.isEmpty) newMocks() + super.beforeEach() } override def afterEach(): Unit = { buffersToClose.foreach(_.close()) buffersToClose.clear() + super.afterEach() } def newMocks(): Unit = { diff --git 
a/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala new file mode 100644 index 00000000000..85884e1e0ff --- /dev/null +++ b/tests/src/test/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriterSuite.scala @@ -0,0 +1,578 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.rapids + +import ai.rapids.cudf.TableWriter +import com.nvidia.spark.rapids.{ColumnarOutputWriter, ColumnarOutputWriterFactory, GpuBoundReference, GpuColumnVector, RapidsBufferCatalog, RapidsDeviceMemoryStore} +import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} +import com.nvidia.spark.rapids.jni.{RetryOOM, SplitAndRetryOOM} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FSDataOutputStream +import org.apache.hadoop.mapred.TaskAttemptContext +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito._ +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite +import org.scalatestplus.mockito.MockitoSugar.mock + +import org.apache.spark.TaskContext +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, ExprId, SortOrder} +import org.apache.spark.sql.execution.datasources.WriteTaskStats +import org.apache.spark.sql.rapids.GpuFileFormatWriter.GpuConcurrentOutputWriterSpec +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} + +class GpuFileFormatDataWriterSuite extends AnyFunSuite with BeforeAndAfterEach { + private var mockJobDescription: GpuWriteJobDescription = _ + private var mockTaskContext: TaskContext = _ + private var mockTaskAttemptContext: TaskAttemptContext = _ + private var mockCommitter: FileCommitProtocol = _ + private var mockOutputWriterFactory: ColumnarOutputWriterFactory = _ + private var mockOutputWriter: NoTransformColumnarOutputWriter = _ + private var devStore: RapidsDeviceMemoryStore = _ + private var allCols: Seq[AttributeReference] = _ + private var partSpec: Seq[AttributeReference] = _ + private var dataSpec: Seq[AttributeReference] = _ + private var includeRetry: Boolean = false + + class NoTransformColumnarOutputWriter( + context: TaskAttemptContext, + dataSchema: StructType, + rangeName: String, + includeRetry: Boolean) + extends ColumnarOutputWriter( + context, + dataSchema, + rangeName, + includeRetry) { + + // this writer (for tests) doesn't do anything and passes through the + // batch passed to it when asked to transform, which is done to + // check for leaks + override def transformAndClose(cb: ColumnarBatch): ColumnarBatch = cb + override val tableWriter: TableWriter = mock[TableWriter] + override def getOutputStream: FSDataOutputStream = 
mock[FSDataOutputStream] + override def path(): String = null + private var throwOnce: Option[Throwable] = None + override def bufferBatchAndClose(batch: ColumnarBatch): Long = { + //closeOnExcept to maintain the contract of `bufferBatchAndClose` + // we have to close the batch. + closeOnExcept(batch) { _ => + throwOnce.foreach { t => + throwOnce = None + throw t + } + } + super.bufferBatchAndClose(batch) + } + + def throwOnNextBufferBatchAndClose(exception: Throwable): Unit = { + throwOnce = Some(exception) + } + + } + + def mockOutputWriter(types: StructType, includeRetry: Boolean): Unit = { + mockOutputWriter = spy(new NoTransformColumnarOutputWriter( + mockTaskAttemptContext, + types, + "", + includeRetry)) + when(mockOutputWriterFactory.newInstance(any(), any(), any())) + .thenAnswer(_ => mockOutputWriter) + } + + def resetMocks(): Unit = { + allCols = null + partSpec = null + dataSpec = null + mockJobDescription = mock[GpuWriteJobDescription] + when(mockJobDescription.statsTrackers).thenReturn(Seq.empty) + mockTaskContext = mock[TaskContext] + mockTaskAttemptContext = mock[TaskAttemptContext] + mockCommitter = mock[FileCommitProtocol] + mockOutputWriterFactory = mock[ColumnarOutputWriterFactory] + when(mockJobDescription.outputWriterFactory) + .thenAnswer(_ => mockOutputWriterFactory) + } + + def mockEmptyOutputWriter(): Unit = { + resetMocks() + mockOutputWriter(StructType(Seq.empty[StructField]), includeRetry = false) + } + + def resetMocksWithAndWithoutRetry[V](body: => V): Unit = { + Seq(false, true).foreach { retry => + resetMocks() + includeRetry = retry + body + } + } + + /** + * This function takes a seq of GPU-backed `ColumnarBatch` instances and a function body. + * It is used to setup certain mocks before `body` is executed. After execution, the + * columns in the batches are checked for `refCount==0` (e.g. that they were closed). + * @note it is assumed that the schema of each batch is identical. 
+ */ + def withColumnarBatchesVerifyClosed[V](cbs: Seq[ColumnarBatch])(body: => V): Unit = { + val allTypes = cbs.map(GpuColumnVector.extractTypes) + allCols = Seq.empty + dataSpec = Seq.empty + partSpec = Seq.empty + if (allTypes.nonEmpty) { + allCols = allTypes.head.zipWithIndex.map { case (dataType, colIx) => + AttributeReference(s"col_$colIx", dataType, nullable = false)(ExprId(colIx)) + } + partSpec = Seq(allCols.head) + dataSpec = allCols.tail + } + val fields = new Array[StructField](allCols.size) + allCols.zipWithIndex.foreach { case (col, ix) => + fields(ix) = StructField(col.name, col.dataType, nullable = col.nullable) + } + mockOutputWriter(StructType(fields), includeRetry) + if (dataSpec.isEmpty) { + dataSpec = allCols // special case for single column batches + } + when(mockJobDescription.dataColumns).thenReturn(dataSpec) + when(mockJobDescription.partitionColumns).thenReturn(partSpec) + when(mockJobDescription.allColumns).thenReturn(allCols) + try { + body + } finally { + verifyClosed(cbs) + } + } + + override def beforeEach(): Unit = { + devStore = new RapidsDeviceMemoryStore() + val catalog = new RapidsBufferCatalog(devStore) + RapidsBufferCatalog.setCatalog(catalog) + } + + override def afterEach(): Unit = { + // test that no buffers we left in the spill framework + assertResult(0)(RapidsBufferCatalog.numBuffers) + RapidsBufferCatalog.close() + devStore.close() + } + + def buildEmptyBatch: ColumnarBatch = + new ColumnarBatch(Array.empty[ColumnVector], 0) + + def buildBatchWithPartitionedCol(ints: Int*): ColumnarBatch = { + val rowCount = ints.size + val cols: Array[ColumnVector] = new Array[ColumnVector](2) + val partCol = ai.rapids.cudf.ColumnVector.fromInts(ints:_*) + val dataCol = ai.rapids.cudf.ColumnVector.fromStrings(ints.map(_.toString):_*) + cols(0) = GpuColumnVector.from(partCol, IntegerType) + cols(1) = GpuColumnVector.from(dataCol, StringType) + new ColumnarBatch(cols, rowCount) + } + + def verifyClosed(cbs: Seq[ColumnarBatch]): Unit = { + cbs.foreach { cb => + val cols = GpuColumnVector.extractBases(cb) + cols.foreach { col => + assertResult(0)(col.getRefCount) + } + } + } + + def prepareDynamicPartitionSingleWriter(): + GpuDynamicPartitionDataSingleWriter = { + when(mockJobDescription.bucketSpec).thenReturn(None) + when(mockJobDescription.customPartitionLocations) + .thenReturn(Map.empty[TablePartitionSpec, String]) + + spy(new GpuDynamicPartitionDataSingleWriter( + mockJobDescription, + mockTaskAttemptContext, + mockCommitter)) + } + + def prepareDynamicPartitionConcurrentWriter(maxWriters: Int, batchSize: Long): + GpuDynamicPartitionDataConcurrentWriter = { + val mockConfig = new Configuration() + when(mockTaskAttemptContext.getConfiguration).thenReturn(mockConfig) + when(mockJobDescription.bucketSpec).thenReturn(None) + when(mockJobDescription.customPartitionLocations) + .thenReturn(Map.empty[TablePartitionSpec, String]) + // assume the first column is the partition-by column + val sortExpr = + GpuBoundReference(0, partSpec.head.dataType, nullable = false)(ExprId(0), "") + val sortSpec = Seq(SortOrder(sortExpr, Ascending)) + val concurrentSpec = GpuConcurrentOutputWriterSpec( + maxWriters, allCols, batchSize, sortSpec) + + spy(new GpuDynamicPartitionDataConcurrentWriter( + mockJobDescription, + mockTaskAttemptContext, + mockCommitter, + concurrentSpec, + mockTaskContext)) + } + + test("empty directory data writer") { + mockEmptyOutputWriter() + val emptyWriter = spy(new GpuEmptyDirectoryDataWriter( + mockJobDescription, mockTaskAttemptContext, 
mockCommitter)) + emptyWriter.writeWithIterator(Iterator.empty) + emptyWriter.commit() + verify(emptyWriter, times(0)) + .write(any[ColumnarBatch]) + verify(emptyWriter, times(1)) + .releaseResources() + } + + test("empty directory data writer with non-empty iterator") { + mockEmptyOutputWriter() + // this should never be the case, as the empty directory writer + // is only instantiated when the iterator is empty. Adding it + // because the expected behavior is to fully consume the iterator + // and close all the empty batches. + val emptyWriter = spy(new GpuEmptyDirectoryDataWriter( + mockJobDescription, mockTaskAttemptContext, mockCommitter)) + val cbs = Seq( + spy(buildEmptyBatch), + spy(buildEmptyBatch)) + emptyWriter.writeWithIterator(cbs.iterator) + emptyWriter.commit() + verify(emptyWriter, times(2)) + .write(any[ColumnarBatch]) + verify(emptyWriter, times(1)) + .releaseResources() + cbs.foreach { cb => verify(cb, times(1)).close()} + } + + test("single directory data writer with empty iterator") { + resetMocksWithAndWithoutRetry { + // build a batch just so that the test code can infer the schema + val cbs = Seq(buildBatchWithPartitionedCol(1)) + withColumnarBatchesVerifyClosed(cbs) { + withResource(cbs) { _ => + val singleWriter = spy(new GpuSingleDirectoryDataWriter( + mockJobDescription, mockTaskAttemptContext, mockCommitter)) + singleWriter.writeWithIterator(Iterator.empty) + singleWriter.commit() + } + } + } + } + + test("single directory data writer") { + resetMocksWithAndWithoutRetry { + val cb = buildBatchWithPartitionedCol(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + val cb2 = buildBatchWithPartitionedCol(1, 2, 3, 4, 5) + val cbs = Seq(spy(cb), spy(cb2)) + withColumnarBatchesVerifyClosed(cbs) { + val singleWriter = spy(new GpuSingleDirectoryDataWriter( + mockJobDescription, mockTaskAttemptContext, mockCommitter)) + singleWriter.writeWithIterator(cbs.iterator) + singleWriter.commit() + // we write 2 batches + verify(mockOutputWriter, times(2)) + .writeSpillableAndClose(any(), any()) + verify(mockOutputWriter, times(1)).close() + } + } + } + + test("single directory data writer with splits") { + resetMocksWithAndWithoutRetry { + val cb = buildBatchWithPartitionedCol(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + val cb2 = buildBatchWithPartitionedCol(1, 2, 3, 4, 5) + val cbs = Seq(spy(cb), spy(cb2)) + withColumnarBatchesVerifyClosed(cbs) { + // setting this to 5 makes the single writer have to split at the 5 row boundary + when(mockJobDescription.maxRecordsPerFile).thenReturn(5) + val singleWriter = spy(new GpuSingleDirectoryDataWriter( + mockJobDescription, mockTaskAttemptContext, mockCommitter)) + + singleWriter.writeWithIterator(cbs.iterator) + singleWriter.commit() + // twice for the first batch given the split, and once for the second batch + verify(mockOutputWriter, times(3)) + .writeSpillableAndClose(any(), any()) + // three because we wrote 3 files (15 rows, limit was 5 rows per file) + verify(mockOutputWriter, times(3)).close() + } + } + } + + test("dynamic partition data writer doesn't support bucketing") { + resetMocksWithAndWithoutRetry { + withColumnarBatchesVerifyClosed(Seq.empty) { + when(mockJobDescription.bucketSpec).thenReturn(Some(GpuWriterBucketSpec(null, null))) + assertThrows[UnsupportedOperationException] { + new GpuDynamicPartitionDataSingleWriter( + mockJobDescription, mockTaskAttemptContext, mockCommitter) + } + } + } + } + + test("dynamic partition data writer without splits") { + resetMocksWithAndWithoutRetry { + // 4 partitions + val cb = buildBatchWithPartitionedCol(1, 
1, 2, 2, 3, 3, 4, 4) + // 5 partitions + val cb2 = buildBatchWithPartitionedCol(1, 2, 3, 4, 5) + val cbs = Seq(spy(cb), spy(cb2)) + withColumnarBatchesVerifyClosed(cbs) { + // setting this to 3 => the writer won't split as no partition has more than 3 rows + when(mockJobDescription.maxRecordsPerFile).thenReturn(3) + val dynamicSingleWriter = prepareDynamicPartitionSingleWriter() + dynamicSingleWriter.writeWithIterator(cbs.iterator) + dynamicSingleWriter.commit() + // we write 9 batches (4 partitions in the first bach, and 5 partitions in the second) + verify(mockOutputWriter, times(9)) + .writeSpillableAndClose(any(), any()) + verify(dynamicSingleWriter, times(9)).newWriter(any(), any(), any()) + // it uses 9 writers because the single writer mode only keeps one writer open at a time + // and once a new partition is seen, the old writer is closed and a new one is opened. + verify(mockOutputWriter, times(9)).close() + } + } + } + + test("dynamic partition data writer with splits") { + resetMocksWithAndWithoutRetry { + val cb = buildBatchWithPartitionedCol(1, 1, 2, 2, 3, 3, 4, 4) + val cb2 = buildBatchWithPartitionedCol(1, 2, 3, 4, 5) + val cbs = Seq(spy(cb), spy(cb2)) + withColumnarBatchesVerifyClosed(cbs) { + // force 1 row batches to be written + when(mockJobDescription.maxRecordsPerFile).thenReturn(1) + val dynamicSingleWriter = prepareDynamicPartitionSingleWriter() + dynamicSingleWriter.writeWithIterator(cbs.iterator) + dynamicSingleWriter.commit() + // we get 13 calls because we write 13 individual batches after splitting + verify(mockOutputWriter, times(13)) + .writeSpillableAndClose(any(), any()) + verify(dynamicSingleWriter, times(13)).newWriter(any(), any(), any()) + // since we have a limit of 1 record per file, we write 13 files + verify(mockOutputWriter, times(13)) + .close() + } + } + } + + test("dynamic partition concurrent data writer with splits") { + resetMocksWithAndWithoutRetry { + // 4 partitions + val cb = buildBatchWithPartitionedCol(1, 1, 2, 2, 3, 3, 4, 4) + // 5 partitions + val cb2 = buildBatchWithPartitionedCol(1, 2, 3, 4, 5) + val cbs = Seq(spy(cb), spy(cb2)) + withColumnarBatchesVerifyClosed(cbs) { + when(mockJobDescription.maxRecordsPerFile).thenReturn(3) + val dynamicConcurrentWriter = + prepareDynamicPartitionConcurrentWriter(maxWriters = 9, batchSize = 1) + dynamicConcurrentWriter.writeWithIterator(cbs.iterator) + dynamicConcurrentWriter.commit() + // we get 9 calls because we have 9 partitions total + verify(mockOutputWriter, times(9)) + .writeSpillableAndClose(any(), any()) + // we write 5 files because we write 1 file per partition, since this concurrent + // writer was able to keep the writers alive + verify(dynamicConcurrentWriter, times(5)).newWriter(any(), any(), any()) + verify(mockOutputWriter, times(5)).close() + } + } + } + + test("dynamic partition concurrent data writer with splits and flush") { + resetMocksWithAndWithoutRetry { + val cb = buildBatchWithPartitionedCol(1, 1, 2, 2, 3, 3, 4, 4) + val cb2 = buildBatchWithPartitionedCol(1, 2, 3, 4, 5) + val cbs = Seq(spy(cb), spy(cb2)) + withColumnarBatchesVerifyClosed(cbs) { + // I would like to not flush on the first iteration of the `write` method + when(mockJobDescription.concurrentWriterPartitionFlushSize).thenReturn(1000) + when(mockJobDescription.maxRecordsPerFile).thenReturn(1) + val dynamicConcurrentWriter = + prepareDynamicPartitionConcurrentWriter(maxWriters = 9, batchSize = 1) + dynamicConcurrentWriter.writeWithIterator(cbs.iterator) + dynamicConcurrentWriter.commit() + + // we get 13 
calls here because we write 1 row files + verify(mockOutputWriter, times(13)) + .writeSpillableAndClose(any(), any()) + verify(dynamicConcurrentWriter, times(13)).newWriter(any(), any(), any()) + + // we have to open 13 writers (1 per row) given the record limit of 1 + verify(mockOutputWriter, times(13)).close() + } + } + } + + test("dynamic partition concurrent data writer fallback without splits") { + resetMocksWithAndWithoutRetry { + val cb = buildBatchWithPartitionedCol(1, 1, 2, 2, 3, 3, 4, 4) + val cb2 = buildBatchWithPartitionedCol(1, 2, 3, 4, 5) + val cbs = Seq(spy(cb), spy(cb2)) + withColumnarBatchesVerifyClosed(cbs) { + // no splitting will occur because the partitions have 3 or less rows. + when(mockJobDescription.maxRecordsPerFile).thenReturn(3) + // because `maxWriters==1` we will fallback right away and + // behave like the dynamic single writer + val dynamicConcurrentWriter = + prepareDynamicPartitionConcurrentWriter(maxWriters = 1, batchSize = 1) + dynamicConcurrentWriter.writeWithIterator(cbs.iterator) + dynamicConcurrentWriter.commit() + // 5 batches written, one per partition (no splitting) + verify(mockOutputWriter, times(5)) + .writeSpillableAndClose(any(), any()) + verify(dynamicConcurrentWriter, times(5)).newWriter(any(), any(), any()) + // 5 files written because this is the single writer mode + verify(mockOutputWriter, times(5)).close() + } + } + } + + test("dynamic partition concurrent data writer fallback with splits") { + resetMocksWithAndWithoutRetry { + val cb = buildBatchWithPartitionedCol(1, 1, 1, 2, 2, 3, 3, 4, 4) + val cb2 = buildBatchWithPartitionedCol(1, 2, 3, 4) + val cb3 = buildBatchWithPartitionedCol(1, 2, 3, 4, 5) // fallback here (5 writers) + val cbs = Seq(spy(cb), spy(cb2), spy(cb3)) + withColumnarBatchesVerifyClosed(cbs) { + // I would like to not flush on the first iteration of the `write` method + when(mockJobDescription.concurrentWriterPartitionFlushSize).thenReturn(1000) + when(mockJobDescription.maxRecordsPerFile).thenReturn(1) + val dynamicConcurrentWriter = + prepareDynamicPartitionConcurrentWriter(maxWriters = 5, batchSize = 1) + dynamicConcurrentWriter.writeWithIterator(cbs.iterator) + dynamicConcurrentWriter.commit() + // 18 batches are written, once per row above given maxRecorsPerFile + verify(mockOutputWriter, times(18)) + .writeSpillableAndClose(any(), any()) + verify(dynamicConcurrentWriter, times(18)).newWriter(any(), any(), any()) + // dynamic partitioning code calls close several times on the same ColumnarOutputWriter + // that doesn't seem to be an issue right now, but verifying that the writer was closed + // is not as clean here, especially during a fallback like in this test. + // A follow on issue is filed to handle this better: + // https://github.com/NVIDIA/spark-rapids/issues/8736 + // verify(mockOutputWriter, times(18)).close() + } + } + } + + test("call newBatch only once when there is a failure writing") { + // this test is to exercise the contract that the ColumnarWriteTaskStatsTracker.newBatch + // has. When there is a retry within writeSpillableAndClose, we will guarantee that + // newBatch will be called only once. If there are exceptions within newBatch they are fatal, + // and are not retried. 
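The attempt counts asserted in the test below follow from the split-and-retry flow described in the comment above: the first call to bufferBatchAndClose covers the whole batch and throws SplitAndRetryOOM, the spillable batch is split in half by rows, and each half is then buffered successfully, so bufferBatchAndClose runs 1 + 2 = 3 times while the stats tracker's newBatch fires only once for the input batch. The following is a minimal, self-contained model of that counting, not the plugin's RmmRapidsRetryIterator; all names in it are illustrative.

// Illustrative sketch only: models split-in-half-by-rows retry counting,
// not the real RmmRapidsRetryIterator. A chunk is just a row count here.
object SplitRetryCountSketch {
  final class SimulatedSplitAndRetry extends RuntimeException("simulated split-and-retry")

  // Returns how many buffering attempts were needed to write `rows` rows when
  // `attempt` may ask for a split by throwing SimulatedSplitAndRetry.
  def bufferWithSplitRetry(rows: Int, attempt: Int => Unit): Int = {
    var attempts = 0
    def run(chunk: Int): Unit = {
      attempts += 1
      try attempt(chunk)
      catch {
        case _: SimulatedSplitAndRetry if chunk > 1 =>
          run(chunk / 2)          // first half
          run(chunk - chunk / 2)  // second half
      }
    }
    run(rows)
    attempts
  }

  def main(args: Array[String]): Unit = {
    var failOnce = true
    val attempts = bufferWithSplitRetry(9, _ => {
      if (failOnce) { failOnce = false; throw new SimulatedSplitAndRetry }
    })
    assert(attempts == 3) // one failed whole-batch attempt plus two successful halves
  }
}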
+ resetMocksWithAndWithoutRetry { + val cb = buildBatchWithPartitionedCol(1, 1, 1, 1, 1, 1, 1, 1, 1) + val cbs = Seq(spy(cb)) + withColumnarBatchesVerifyClosed(cbs) { + // I would like to not flush on the first iteration of the `write` method + when(mockJobDescription.concurrentWriterPartitionFlushSize).thenReturn(1000) + when(mockJobDescription.maxRecordsPerFile).thenReturn(9) + + val statsTracker = mock[ColumnarWriteTaskStatsTracker] + val jobTracker = new ColumnarWriteJobStatsTracker { + override def newTaskInstance(): ColumnarWriteTaskStatsTracker = { + statsTracker + } + override def processStats(stats: Seq[WriteTaskStats], jobCommitTime: Long): Unit = {} + } + when(mockJobDescription.statsTrackers) + .thenReturn(Seq(jobTracker)) + + // throw once from bufferBatchAndClose to simulate an exception after we call the + // stats tracker + mockOutputWriter.throwOnNextBufferBatchAndClose( + new SplitAndRetryOOM("mocking a split and retry")) + val dynamicConcurrentWriter = + prepareDynamicPartitionConcurrentWriter(maxWriters = 5, batchSize = 1) + + if (includeRetry) { + dynamicConcurrentWriter.writeWithIterator(cbs.iterator) + dynamicConcurrentWriter.commit() + } else { + assertThrows[SplitAndRetryOOM] { + dynamicConcurrentWriter.writeWithIterator(cbs.iterator) + dynamicConcurrentWriter.commit() + } + } + + // 1 batch is written, all rows fit + verify(mockOutputWriter, times(1)) + .writeSpillableAndClose(any(), any()) + // we call newBatch once + verify(statsTracker, times(1)).newBatch(any(), any()) + if (includeRetry) { + // we call it 3 times, once for the first whole batch that fails with OOM + // and twice for the two halves after we handle the OOM + verify(mockOutputWriter, times(3)).bufferBatchAndClose(any()) + } else { + // once and we fail, so we don't retry + verify(mockOutputWriter, times(1)).bufferBatchAndClose(any()) + } + } + } + } + + test("newBatch throwing is fatal") { + // this test is to exercise the contract that the ColumnarWriteTaskStatsTracker.newBatch + // has. When there is a retry within writeSpillableAndClose, we will guarantee that + // newBatch will be called only once. If there are exceptions within newBatch they are fatal, + // and are not retried. 
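The test below depends on newBatch being reported to the stats trackers exactly once per input batch, before any retryable buffering work starts; because that call sits outside the retried region, an exception thrown from it propagates immediately instead of being retried. The sketch that follows shows that ordering with hypothetical names; the real writeSpillableAndClose in ColumnarOutputWriter is more involved, this only models the contract the assertions check.

// Illustrative ordering sketch with hypothetical names; not the plugin code.
object NewBatchIsFatalSketch {
  final class SimulatedRetryOOM extends RuntimeException("simulated retry OOM")

  trait StatsTrackerLike { def newBatch(rows: Long): Unit }

  // newBatch is reported once, outside the retry loop; only `buffer` is retried.
  def writeSketch(rows: Long, tracker: StatsTrackerLike, buffer: Long => Unit): Unit = {
    tracker.newBatch(rows) // a throw here escapes immediately: no retry, no buffering
    var attempts = 0
    var done = false
    while (!done) {
      attempts += 1
      try { buffer(rows); done = true }
      catch { case _: SimulatedRetryOOM if attempts < 3 => () } // retry buffering only
    }
  }

  def main(args: Array[String]): Unit = {
    var buffered = 0
    val failingTracker = new StatsTrackerLike {
      def newBatch(rows: Long): Unit = throw new SimulatedRetryOOM
    }
    try writeSketch(9L, failingTracker, _ => buffered += 1)
    catch { case _: SimulatedRetryOOM => () }
    assert(buffered == 0) // buffering is never reached when newBatch throws
  }
}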
+ resetMocksWithAndWithoutRetry { + val cb = buildBatchWithPartitionedCol(1, 1, 1, 1, 1, 1, 1, 1, 1) + val cbs = Seq(spy(cb)) + withColumnarBatchesVerifyClosed(cbs) { + // I would like to not flush on the first iteration of the `write` method + when(mockJobDescription.concurrentWriterPartitionFlushSize).thenReturn(1000) + when(mockJobDescription.maxRecordsPerFile).thenReturn(9) + + val statsTracker = mock[ColumnarWriteTaskStatsTracker] + val jobTracker = new ColumnarWriteJobStatsTracker { + override def newTaskInstance(): ColumnarWriteTaskStatsTracker = { + statsTracker + } + + override def processStats(stats: Seq[WriteTaskStats], jobCommitTime: Long): Unit = {} + } + when(mockJobDescription.statsTrackers) + .thenReturn(Seq(jobTracker)) + when(statsTracker.newBatch(any(), any())) + .thenThrow(new RetryOOM("mocking a retry")) + val dynamicConcurrentWriter = + prepareDynamicPartitionConcurrentWriter(maxWriters = 5, batchSize = 1) + + assertThrows[RetryOOM] { + dynamicConcurrentWriter.writeWithIterator(cbs.iterator) + dynamicConcurrentWriter.commit() + } + + // we never reach the buffer stage + verify(mockOutputWriter, times(0)).bufferBatchAndClose(any()) + // we attempt to write one batch + verify(mockOutputWriter, times(1)) + .writeSpillableAndClose(any(), any()) + // we call newBatch once + verify(statsTracker, times(1)).newBatch(any(), any()) + } + } + } +} diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/filecache/FileCacheIntegrationSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/filecache/FileCacheIntegrationSuite.scala index 56990df2ba1..b3cf0c04d35 100644 --- a/tests/src/test/scala/org/apache/spark/sql/rapids/filecache/FileCacheIntegrationSuite.scala +++ b/tests/src/test/scala/org/apache/spark/sql/rapids/filecache/FileCacheIntegrationSuite.scala @@ -16,14 +16,15 @@ package org.apache.spark.sql.rapids.filecache -import com.nvidia.spark.rapids.SparkQueryCompareTestSuite +import com.nvidia.spark.rapids.{RapidsBufferCatalog, RapidsDeviceMemoryStore, SparkQueryCompareTestSuite} import com.nvidia.spark.rapids.shims.GpuBatchScanExec +import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkConf import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.rapids.GpuFileSourceScanExec -class FileCacheIntegrationSuite extends SparkQueryCompareTestSuite { +class FileCacheIntegrationSuite extends SparkQueryCompareTestSuite with BeforeAndAfterEach { import com.nvidia.spark.rapids.GpuMetric._ private val FILE_SPLITS_PARQUET = "file-splits.parquet" @@ -31,6 +32,17 @@ class FileCacheIntegrationSuite extends SparkQueryCompareTestSuite { private val MAP_OF_STRINGS_PARQUET = "map_of_strings.snappy.parquet" private val SCHEMA_CANT_PRUNE_ORC = "schema-cant-prune.orc" + override def beforeEach(): Unit = { + val deviceStorage = new RapidsDeviceMemoryStore() + val catalog = new RapidsBufferCatalog(deviceStorage) + RapidsBufferCatalog.setDeviceStorage(deviceStorage) + RapidsBufferCatalog.setCatalog(catalog) + } + + override def afterEach(): Unit = { + RapidsBufferCatalog.close() + } + def isFileCacheEnabled(conf: SparkConf): Boolean = { // File cache only supported on Spark 3.2+ assumeSpark320orLater
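Both new suites wire up the spill framework the same way around each test: build a RapidsDeviceMemoryStore, register a fresh RapidsBufferCatalog, and tear everything down afterwards (GpuFileFormatDataWriterSuite additionally asserts that no spillable buffers were leaked). Below is a condensed sketch of that shared fixture, using only the catalog and store APIs exercised elsewhere in this patch; the trait name is hypothetical and this is not part of the change itself.

import com.nvidia.spark.rapids.{RapidsBufferCatalog, RapidsDeviceMemoryStore}
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite

// Illustrative fixture trait (name is hypothetical); mirrors the beforeEach/afterEach
// pattern used by GpuFileFormatDataWriterSuite and FileCacheIntegrationSuite above.
trait SpillFrameworkPerTestFixture extends BeforeAndAfterEach { this: AnyFunSuite =>
  private var devStore: RapidsDeviceMemoryStore = _

  override def beforeEach(): Unit = {
    super.beforeEach()
    devStore = new RapidsDeviceMemoryStore()
    RapidsBufferCatalog.setCatalog(new RapidsBufferCatalog(devStore))
  }

  override def afterEach(): Unit = {
    try {
      // fail fast if a test left spillable state behind
      assertResult(0)(RapidsBufferCatalog.numBuffers)
    } finally {
      RapidsBufferCatalog.close()
      devStore.close()
      super.afterEach()
    }
  }
}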