From 6c50e8d5f9a2e8aa2c84218b28510652229b4d55 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 8 Aug 2023 08:23:57 -0500 Subject: [PATCH] Device synchronize prior to freeing a set of RapidsBuffer (#8936) * Device synchronize prior to freeing a set of RapidsBuffer Signed-off-by: Alessandro Bellina Co-authored-by: Jason Lowe --------- Signed-off-by: Alessandro Bellina Co-authored-by: Jason Lowe --- .../spark/rapids/RapidsBufferCatalog.scala | 66 +++++++++++++------ 1 file changed, 45 insertions(+), 21 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala index 724b71fe631..8444daf790d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala @@ -19,6 +19,8 @@ package com.nvidia.spark.rapids import java.util.concurrent.ConcurrentHashMap import java.util.function.BiFunction +import scala.collection.mutable.ArrayBuffer + import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, NvtxColor, NvtxRange, Rmm, Table} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsBufferCatalog.getExistingRapidsBufferAndAcquire @@ -518,29 +520,47 @@ class RapidsBufferCatalog( // If the store has 0 spillable bytes left, it has exhausted. 
var exhausted = false - while (!exhausted && !rmmShouldRetryAlloc && + val buffersToFree = new ArrayBuffer[RapidsBuffer]() + try { + while (!exhausted && !rmmShouldRetryAlloc && store.currentSpillableSize > targetTotalSize) { - val mySpillCount = spillCount - synchronized { - if (spillCount == mySpillCount) { - spillCount += 1 - val nextSpillable = store.nextSpillable() - if (nextSpillable != null) { - // we have a buffer (nextSpillable) to spill - spillAndFreeBuffer(nextSpillable, spillStore, stream) - totalSpilled += nextSpillable.getMemoryUsedBytes + val mySpillCount = spillCount + synchronized { + if (spillCount == mySpillCount) { + spillCount += 1 + val nextSpillable = store.nextSpillable() + if (nextSpillable != null) { + // we have a buffer (nextSpillable) to spill + // spill it and store it in `buffersToFree` to + // free all in one go after a synchronize. + spillBuffer(nextSpillable, spillStore, stream) + .foreach(buffersToFree.append(_)) + totalSpilled += nextSpillable.getMemoryUsedBytes + } + } else { + rmmShouldRetryAlloc = true } - } else { - rmmShouldRetryAlloc = true } - } - if (!rmmShouldRetryAlloc && totalSpilled <= 0) { - // we didn't spill in this iteration, exit loop - exhausted = true - logWarning("Unable to spill enough to meet request. " + + if (!rmmShouldRetryAlloc && totalSpilled <= 0) { + // we didn't spill in this iteration, exit loop + exhausted = true + logWarning("Unable to spill enough to meet request. " + s"Total=${store.currentSize} " + s"Spillable=${store.currentSpillableSize} " + s"Target=$targetTotalSize") + } + } + } finally { + if (buffersToFree.nonEmpty) { + // This is a hack in order to completely synchronize with the GPU before we free + // a buffer. It is necessary because of non-synchronous cuDF calls that could fall + // behind where the CPU is. Freeing a rapids buffer in these cases needs to wait for + // all launched GPU work, otherwise crashes or data corruption could occur. 
+ // A more performant implementation would be to synchronize on the thread that read + // the buffer via events. + // https://github.com/NVIDIA/spark-rapids/issues/8610 + Cuda.deviceSynchronize() + buffersToFree.safeFree() } } } @@ -557,12 +577,13 @@ class RapidsBufferCatalog( /** * Given a specific `RapidsBuffer` spill it to `spillStore` + * @return the buffer, if successfully spilled, in order for the caller to free it * @note called with catalog lock held */ - private def spillAndFreeBuffer( + private def spillBuffer( buffer: RapidsBuffer, spillStore: RapidsBufferStore, - stream: Cuda.Stream): Unit = { + stream: Cuda.Stream): Option[RapidsBuffer] = { if (buffer.addReference()) { withResource(buffer) { _ => logDebug(s"Spilling $buffer ${buffer.id} to ${spillStore.name}") @@ -584,8 +605,11 @@ class RapidsBufferCatalog( } // we can now remove the old tier linkage removeBufferTier(buffer.id, buffer.storageTier) - // and free - buffer.safeFree() + + // return the buffer + Some(buffer) + } else { + None } }