diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemReaderForUnsafeExternalSorter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemReaderForUnsafeExternalSorter.java index f970b0dd..4a39648b 100644 --- a/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemReaderForUnsafeExternalSorter.java +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemReaderForUnsafeExternalSorter.java @@ -35,7 +35,7 @@ public PMemReaderForUnsafeExternalSorter( this.numRecordsRemaining = numRecords - position/2; this.taskMetrics = taskMetrics; int readBufferSize = SparkEnv.get() == null? 8 * 1024 * 1024 : - (int) SparkEnv.get().conf().get(package$.MODULE$.MEMORY_SPILL_PMEM_READ_BUFFERSIZE()); + (int) (long) SparkEnv.get().conf().get(package$.MODULE$.MEMORY_SPILL_PMEM_READ_BUFFERSIZE()); logger.info("PMem read buffer size is:" + Utils.bytesToString(readBufferSize)); this.byteBuffer = ByteBuffer.wrap(new byte[readBufferSize]); byteBuffer.flip(); diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemWriter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemWriter.java index 2db27343..906ab506 100644 --- a/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemWriter.java +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemWriter.java @@ -95,7 +95,7 @@ public void write() throws IOException { long writeDuration = 0; ExecutorService executorService = Executors.newSingleThreadExecutor(); Future future = executorService.submit(()->dumpPagesToPMem()); - externalSorter.getInMemSorter().getSortedIterator(); + inMemSorter.getSortedIterator(); try { writeDuration = future.get(); } catch (InterruptedException | ExecutionException e) { @@ -106,7 +106,7 @@ public void write() throws IOException { } else if(!isSorted) { dumpPagesToPMem(); // get sorted iterator - externalSorter.getInMemSorter().getSortedIterator(); + inMemSorter.getSortedIterator(); // update LongArray updateLongArray(inMemSorter.getArray(), totalRecordsWritten, 0); } else { @@ -121,7 +121,7 @@ public void write() throws IOException { diskSpillWriter = new UnsafeSorterSpillWriter( blockManager, fileBufferSize, - sortedIterator, + isSorted? sortedIterator : inMemSorter.getSortedIterator(), numberOfRecordsToWritten, serializerManager, writeMetrics, @@ -138,6 +138,7 @@ public boolean allocatePMemPages(LinkedList dramPages, MemoryBlock allocatedPMemPages.add(pMemBlock); pageMap.put(page, pMemBlock); } else { + freeAllPMemPages(); pageMap.clear(); return false; } @@ -147,6 +148,7 @@ public boolean allocatePMemPages(LinkedList dramPages, MemoryBlock allocatedPMemPages.add(pMemPageForLongArray); pageMap.put(longArrayPage, pMemPageForLongArray); } else { + freeAllPMemPages(); pageMap.clear(); return false; } diff --git a/src/main/scala/org/apache/spark/SparkEnv.scala b/src/main/scala/org/apache/spark/SparkEnv.scala index 8ba17398..da83f262 100644 --- a/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/src/main/scala/org/apache/spark/SparkEnv.scala @@ -27,6 +27,8 @@ import scala.collection.mutable import scala.util.Properties import com.google.common.cache.CacheBuilder +import com.intel.oap.common.unsafe.PersistentMemoryPlatform + import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.DeveloperApi @@ -244,6 +246,14 @@ object SparkEnv extends Logging { val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER + val pMemEnabled = conf.get(MEMORY_SPILL_PMEM_ENABLED); + if (pMemEnabled && !isDriver) { + val pMemInitialPath = conf.get(MEMORY_EXTENDED_PATH) + val pMemInitialSize = conf.get(MEMORY_EXTENDED_SIZE) + PersistentMemoryPlatform.initialize(pMemInitialPath, pMemInitialSize, 0) + logInfo(s"PMem initialize path: ${pMemInitialPath}, size: ${pMemInitialSize} ") + } + // Listener bus is only used on the driver if (isDriver) { assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!") diff --git a/src/main/scala/org/apache/spark/storage/BlockManager.scala b/src/main/scala/org/apache/spark/storage/BlockManager.scala index ee2ef7e7..5b3379c8 100644 --- a/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -189,7 +189,7 @@ private[spark] class BlockManager( val pmemMode = conf.get("spark.memory.pmem.mode", "AppDirect") val numNum = conf.getInt("spark.yarn.numa.num", 2) - if (pmemMode.equals("AppDirect")) { + if (memExtensionEnabled && pmemMode.equals("AppDirect")) { if (!isDriver && pmemInitialPaths.size >= 1) { if (numaNodeId == -1) { numaNodeId = executorId.toInt @@ -213,7 +213,7 @@ private[spark] class BlockManager( PersistentMemoryPlatform.initialize(file.getAbsolutePath, pmemInitialSize, 0) logInfo(s"Intel Optane PMem initialized with path: ${file.getAbsolutePath}, size: ${pmemInitialSize} ") } - } else if (pmemMode.equals("KMemDax")) { + } else if (memExtensionEnabled && pmemMode.equals("KMemDax")) { if (!isDriver) { if (numaNodeId == -1) { numaNodeId = (executorId.toInt + 1) % 2