Device synchronize prior to freeing a set of RapidsBuffer (NVIDIA#8936)
* Device synchronize prior to freeing a set of RapidsBuffer

Signed-off-by: Alessandro Bellina <[email protected]>
Co-authored-by: Jason Lowe <[email protected]>

---------

Signed-off-by: Alessandro Bellina <[email protected]>
Co-authored-by: Jason Lowe <[email protected]>
abellina and jlowe authored Aug 8, 2023
1 parent 7d2ce0f commit 6c50e8d
Showing 1 changed file with 45 additions and 21 deletions.
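Why the synchronize is needed: cuDF calls are asynchronous with respect to the host, so device memory must not be handed back to RMM while the GPU may still be reading it. Below is a minimal standalone sketch of the ordering this commit enforces in RapidsBufferCatalog; it is not spark-rapids code, and enqueueAsyncRead is a hypothetical stand-in for any cuDF call that only queues work on a stream.

import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer}

object SyncBeforeFreeSketch {
  // Hypothetical placeholder: pretend this launches an asynchronous device read
  // (for example a spill copy) on `stream` and returns before the GPU finishes.
  def enqueueAsyncRead(buffer: DeviceMemoryBuffer, stream: Cuda.Stream): Unit = ()

  def spillThenFree(buffers: Seq[DeviceMemoryBuffer], stream: Cuda.Stream): Unit = {
    buffers.foreach(enqueueAsyncRead(_, stream))
    // Without this synchronize, close() below could hand the memory back to RMM
    // while the GPU is still reading it, risking crashes or data corruption.
    Cuda.deviceSynchronize()
    buffers.foreach(_.close())
  }
}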
@@ -19,6 +19,8 @@ package com.nvidia.spark.rapids
 import java.util.concurrent.ConcurrentHashMap
 import java.util.function.BiFunction
 
+import scala.collection.mutable.ArrayBuffer
+
 import ai.rapids.cudf.{ContiguousTable, Cuda, DeviceMemoryBuffer, NvtxColor, NvtxRange, Rmm, Table}
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
 import com.nvidia.spark.rapids.RapidsBufferCatalog.getExistingRapidsBufferAndAcquire
@@ -518,29 +520,47 @@ class RapidsBufferCatalog(
     // If the store has 0 spillable bytes left, it has exhausted.
     var exhausted = false
 
-    while (!exhausted && !rmmShouldRetryAlloc &&
+    val buffersToFree = new ArrayBuffer[RapidsBuffer]()
+    try {
+      while (!exhausted && !rmmShouldRetryAlloc &&
         store.currentSpillableSize > targetTotalSize) {
-      val mySpillCount = spillCount
-      synchronized {
-        if (spillCount == mySpillCount) {
-          spillCount += 1
-          val nextSpillable = store.nextSpillable()
-          if (nextSpillable != null) {
-            // we have a buffer (nextSpillable) to spill
-            spillAndFreeBuffer(nextSpillable, spillStore, stream)
-            totalSpilled += nextSpillable.getMemoryUsedBytes
-          }
-        } else {
-          rmmShouldRetryAlloc = true
-        }
-      }
-      if (!rmmShouldRetryAlloc && totalSpilled <= 0) {
-        // we didn't spill in this iteration, exit loop
-        exhausted = true
-        logWarning("Unable to spill enough to meet request. " +
+        val mySpillCount = spillCount
+        synchronized {
+          if (spillCount == mySpillCount) {
+            spillCount += 1
+            val nextSpillable = store.nextSpillable()
+            if (nextSpillable != null) {
+              // we have a buffer (nextSpillable) to spill
+              // spill it and store it in `buffersToFree` to
+              // free all in one go after a synchronize.
+              spillBuffer(nextSpillable, spillStore, stream)
+                .foreach(buffersToFree.append(_))
+              totalSpilled += nextSpillable.getMemoryUsedBytes
+            }
+          } else {
+            rmmShouldRetryAlloc = true
+          }
+        }
+        if (!rmmShouldRetryAlloc && totalSpilled <= 0) {
+          // we didn't spill in this iteration, exit loop
+          exhausted = true
+          logWarning("Unable to spill enough to meet request. " +
             s"Total=${store.currentSize} " +
             s"Spillable=${store.currentSpillableSize} " +
             s"Target=$targetTotalSize")
-      }
-    }
+        }
+      }
+    } finally {
+      if (buffersToFree.nonEmpty) {
+        // This is a hack in order to completely synchronize with the GPU before we free
+        // a buffer. It is necessary because of non-synchronous cuDF calls that could fall
+        // behind where the CPU is. Freeing a rapids buffer in these cases needs to wait for
+        // all launched GPU work, otherwise crashes or data corruption could occur.
+        // A more performant implementation would be to synchronize on the thread that read
+        // the buffer via events.
+        // https://github.com/NVIDIA/spark-rapids/issues/8610
+        Cuda.deviceSynchronize()
+        buffersToFree.safeFree()
+      }
+    }
   }
@@ -557,12 +577,13 @@ class RapidsBufferCatalog(
 
   /**
    * Given a specific `RapidsBuffer` spill it to `spillStore`
+   * @return the buffer, if successfully spilled, in order for the caller to free it
    * @note called with catalog lock held
    */
-  private def spillAndFreeBuffer(
+  private def spillBuffer(
       buffer: RapidsBuffer,
       spillStore: RapidsBufferStore,
-      stream: Cuda.Stream): Unit = {
+      stream: Cuda.Stream): Option[RapidsBuffer] = {
     if (buffer.addReference()) {
       withResource(buffer) { _ =>
         logDebug(s"Spilling $buffer ${buffer.id} to ${spillStore.name}")
@@ -584,8 +605,11 @@
         }
         // we can now remove the old tier linkage
         removeBufferTier(buffer.id, buffer.storageTier)
-        // and free
-        buffer.safeFree()
+
+        // return the buffer
+        Some(buffer)
       }
+    } else {
+      None
     }
   }
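The comment in the new finally block, together with https://github.com/NVIDIA/spark-rapids/issues/8610, points at a cheaper follow-up: wait only for the work already enqueued on the spilling stream instead of stalling the whole device. The sketch below is a hedged illustration of that idea, assuming cudf's Java Cuda.Event API (no-arg constructor, record(stream), sync()) and a RapidsBuffer.free() call; freeSpilledBuffers is a hypothetical helper, not part of this patch, and it only covers reads issued on the single stream passed in, whereas a complete fix would need to track the streams that actually read each buffer.

import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.Cuda
import com.nvidia.spark.rapids.RapidsBuffer
import com.nvidia.spark.rapids.Arm.withResource

object EventBasedFreeSketch {
  // Hypothetical alternative to the Cuda.deviceSynchronize() in the finally block:
  // record an event once the spill copies have been enqueued on `stream`, wait for
  // just that event, and only then release the buffers.
  def freeSpilledBuffers(
      buffersToFree: ArrayBuffer[RapidsBuffer],
      stream: Cuda.Stream): Unit = {
    if (buffersToFree.nonEmpty) {
      withResource(new Cuda.Event()) { spillsDone =>
        spillsDone.record(stream) // captures all work enqueued on `stream` so far
        spillsDone.sync()         // blocks this thread until that work completes
      }
      // Safe to release now. The patch itself uses the safeFree() helper; plain
      // free() is used here to keep the sketch self-contained.
      buffersToFree.foreach(_.free())
    }
  }
}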

