Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
[PMEM-SHUFFLE-46] Fix the bug that off-heap memory is over used in sh…
Browse files Browse the repository at this point in the history
…uffle reduce stage
  • Loading branch information
Eugene-Mark committed Aug 20, 2021
1 parent 5e8ab2d commit 7f18715
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ private[spark] class BaseShuffleReader[K, C](handle: BaseShuffleHandle[K, _, C],
/**
* Force iterator to traverse itself and update internal counter
**/
wrappedStreams.size
//wrappedStreams.size

val serializerInstance = dep.serializer.newInstance()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ private[spark] class PmofShuffleManager(conf: SparkConf) extends ShuffleManager
override def getReader[K, C](handle: _root_.org.apache.spark.shuffle.ShuffleHandle, startMapIndex: Int, endMapIndex: Int, startPartition: Int, endPartition: Int, context: _root_.org.apache.spark.TaskContext, readMetrics: ShuffleReadMetricsReporter): _root_.org.apache.spark.shuffle.ShuffleReader[K, C] = {
val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
val env: SparkEnv = SparkEnv.get
if (pmofConf.enableRemotePmem) {
new RpmpShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
Expand All @@ -73,6 +72,7 @@ private[spark] class PmofShuffleManager(conf: SparkConf) extends ShuffleManager
context,
pmofConf)
} else if (pmofConf.enableRdma) {
val env: SparkEnv = SparkEnv.get
metadataResolver = MetadataResolver.getMetadataResolver(pmofConf)
PmofTransferService.getTransferServiceInstance(pmofConf, env.blockManager, this)
new RdmaShuffleReader(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ trait PmemBlockInputStream[K, C] {
}

class LocalPmemBlockInputStream[K, C](
blockId: BlockId,
total_records: Long,
pmemBlockOutputStream: PmemBlockOutputStream,
serializer: Serializer)
extends PmemBlockInputStream[K, C] {
val blockId: BlockId = pmemBlockOutputStream.getBlockId()
val serializerManager: SerializerManager = SparkEnv.get.serializerManager
val serInstance: SerializerInstance = serializer.newInstance()
val persistentMemoryWriter: PersistentMemoryHandler =
Expand All @@ -27,7 +27,9 @@ class LocalPmemBlockInputStream[K, C](
var inObjStream: DeserializationStream = serInstance.deserializeStream(wrappedStream)

var indexInBatch: Int = 0
var total_records: Long = 0
var closing: Boolean = false
total_records = pmemBlockOutputStream.getTotalRecords()

def readNextItem(): (K, C) = {
if (closing == true) {
Expand Down Expand Up @@ -126,4 +128,4 @@ class RemotePmemBlockInputStream[K, C](
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,13 @@ private[spark] class PmemBlockOutputStream(
if ((pmofConf.spill_throttle != -1 && pmemOutputStream.bufferRemainingSize >= pmofConf.spill_throttle) || force == true) {
val start = System.nanoTime()
flush()
//pmemOutputStream.doFlush()
pmemOutputStream.doFlush()
val bufSize = pmemOutputStream.flushedSize
mapStatus += ((pmemOutputStream.flushed_block_id, bufSize, recordsPerBlock))
if (bufSize > 0) {
recordsArray += recordsPerBlock
recordsPerBlock = 0
size += bufSize
size = bufSize

if (blockId.isShuffle == true) {
val writeMetrics = taskMetrics.shuffleWriteMetrics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class PmemOutputStream(
var is_closed = false
var key_id = 0

val length: Int = 1024 * 1024 * 6
val length: Int = bufferSize
var bufferFlushedSize: Int = 0
var bufferRemainingSize: Int = 0
val buf: ByteBuf = NettyByteBufferPool.allocateFlexibleNewBuffer(length)
Expand All @@ -37,6 +37,10 @@ class PmemOutputStream(
}

override def flush(): Unit = {

}

def doFlush(): Unit = {
if (bufferRemainingSize > 0) {
if (remotePersistentMemoryPool != null) {
logDebug(s" [PUT Started]${cur_block_id}-${bufferRemainingSize}")
Expand Down Expand Up @@ -73,10 +77,6 @@ class PmemOutputStream(
}
}

def doFlush(): Unit = {

}

def flushedSize(): Int = {
bufferFlushedSize
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,7 @@ private[spark] class PmemExternalSorter[K, V, C](
// which is different from spark original codes (relate to one spill file)
val pmemBlockInputStream = if (!pmofConf.enableRemotePmem) {
new LocalPmemBlockInputStream[K, C](
pmemBlockOutputStream.getBlockId,
pmemBlockOutputStream.getTotalRecords,
pmemBlockOutputStream,
serializer)
} else {
new RemotePmemBlockInputStream[K, C](
Expand Down

0 comments on commit 7f18715

Please sign in to comment.