diff --git a/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/src/main/java/org/apache/spark/memory/MemoryConsumer.java
new file mode 100644
index 00000000..12e360ac
--- /dev/null
+++ b/src/main/java/org/apache/spark/memory/MemoryConsumer.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory;
+
+import java.io.IOException;
+
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+/**
+ * A memory consumer of {@link TaskMemoryManager} that supports spilling.
+ *
+ * Note: this only supports allocation / spilling of Tungsten memory.
+ */
+public abstract class MemoryConsumer {
+
+ protected final TaskMemoryManager taskMemoryManager;
+ private final long pageSize;
+ private final MemoryMode mode;
+ protected long used;
+
+ protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize, MemoryMode mode) {
+ this.taskMemoryManager = taskMemoryManager;
+ this.pageSize = pageSize;
+ this.mode = mode;
+ }
+
+ protected MemoryConsumer(TaskMemoryManager taskMemoryManager) {
+ this(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP);
+ }
+
+ /**
+ * Returns the memory mode, {@link MemoryMode#ON_HEAP} or {@link MemoryMode#OFF_HEAP}.
+ */
+ public MemoryMode getMode() {
+ return mode;
+ }
+
+ /**
+ * Returns the size of used memory in bytes.
+ */
+ public long getUsed() {
+ return used;
+ }
+
+ /**
+ * Force spill during building.
+ */
+ public void spill() throws IOException {
+ spill(Long.MAX_VALUE, this);
+ }
+
+ /**
+ * Spill some data to disk to release memory, which will be called by TaskMemoryManager
+ * when there is not enough memory for the task.
+ *
+ * This should be implemented by subclasses.
+ *
+ * Note: in order to avoid possible deadlocks, spill() should not call acquireMemory().
+ *
+ * Note: today, this only frees Tungsten-managed pages.
+ *
+ * @param size the amount of memory that should be released
+ * @param trigger the MemoryConsumer that triggered this spilling
+ * @return the amount of released memory in bytes
+ */
+ public abstract long spill(long size, MemoryConsumer trigger) throws IOException;
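+ // Illustrative sketch only (not prescribed by this class): a typical spill() implementation
+ // writes its data pages out through a spill writer, frees them with freePage(), and returns
+ // the number of bytes released, e.g.
+ //   long released = 0L;
+ //   for (MemoryBlock page : pages) {   // `pages` and `writePage` are hypothetical members
+ //     writePage(page);
+ //     released += page.size();
+ //     freePage(page);
+ //   }
+ //   return released;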
+
+ /**
+ * Allocates a LongArray of `size`. Note that this method may throw `SparkOutOfMemoryError`
+ * if Spark doesn't have enough memory for this allocation, or throw `TooLargePageException`
+ * if this `LongArray` is too large to fit in a single page. Callers should handle these two
+ * exceptions, or make sure `size` is small enough that it will not trigger them.
+ *
+ * @throws SparkOutOfMemoryError
+ * @throws TooLargePageException
+ */
+ public LongArray allocateArray(long size) {
+ long required = size * 8L;
+ MemoryBlock page = taskMemoryManager.allocatePage(required, this);
+ if (page == null || page.size() < required) {
+ throwOom(page, required);
+ }
+ used += required;
+ return new LongArray(page);
+ }
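+ // Illustrative usage sketch (assumes a hypothetical `array` field and a caller-computed
+ // `newCapacity`): a consumer that keeps record pointers in a LongArray might grow it with
+ //   LongArray newArray = allocateArray(newCapacity);   // may throw SparkOutOfMemoryError
+ //   Platform.copyMemory(array.getBaseObject(), array.getBaseOffset(),
+ //       newArray.getBaseObject(), newArray.getBaseOffset(), array.size() * 8L);
+ //   freeArray(array);
+ //   array = newArray;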
+
+ /**
+ * Frees a LongArray.
+ */
+ public void freeArray(LongArray array) {
+ freePage(array.memoryBlock());
+ }
+
+ /**
+ * Allocate a memory block with at least `required` bytes.
+ *
+ * @throws SparkOutOfMemoryError
+ */
+ protected MemoryBlock allocatePage(long required) {
+ MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this);
+ if (page == null || page.size() < required) {
+ throwOom(page, required);
+ }
+ used += page.size();
+ return page;
+ }
+
+ /**
+ * Free a memory block.
+ */
+ protected void freePage(MemoryBlock page) {
+ used -= page.size();
+ taskMemoryManager.freePage(page, this);
+ }
+
+ /**
+ * Allocates memory of `size`.
+ */
+ public long acquireMemory(long size) {
+ long granted = taskMemoryManager.acquireExecutionMemory(size, this);
+ used += granted;
+ return granted;
+ }
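+ // Note: even after the TaskMemoryManager has asked other consumers to spill, the grant may be
+ // smaller than `size`; callers should check the returned value and spill or back off
+ // themselves when it is insufficient.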
+
+ /**
+ * Release N bytes of memory.
+ */
+ public void freeMemory(long size) {
+ taskMemoryManager.releaseExecutionMemory(size, this);
+ used -= size;
+ }
+
+ private void throwOom(final MemoryBlock page, final long required) {
+ long got = 0;
+ if (page != null) {
+ got = page.size();
+ taskMemoryManager.freePage(page, this);
+ }
+ taskMemoryManager.showMemoryUsage();
+ // checkstyle.off: RegexpSinglelineJava
+ throw new SparkOutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " +
+ got);
+ // checkstyle.on: RegexpSinglelineJava
+ }
+
+ public TaskMemoryManager getTaskMemoryManager() {
+ return taskMemoryManager;
+ }
+}
diff --git a/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
new file mode 100644
index 00000000..6bbe33e0
--- /dev/null
+++ b/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -0,0 +1,619 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory;
+
+import javax.annotation.concurrent.GuardedBy;
+import java.io.IOException;
+import java.nio.channels.ClosedByInterruptException;
+import java.util.Arrays;
+import java.util.ArrayList;
+import java.util.BitSet;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.util.Utils;
+
+/**
+ * Manages the memory allocated by an individual task.
+ *
+ * Most of the complexity in this class deals with encoding of off-heap addresses into 64-bit longs.
+ * In off-heap mode, memory can be directly addressed with 64-bit longs. In on-heap mode, memory is
+ * addressed by the combination of a base Object reference and a 64-bit offset within that object.
+ * This is a problem when we want to store pointers to data structures inside of other structures,
+ * such as record pointers inside hashmaps or sorting buffers. Even if we decided to use 128 bits
+ * to address memory, we can't just store the address of the base object since it's not guaranteed
+ * to remain stable as the heap gets reorganized due to GC.
+ *
+ * Instead, we use the following approach to encode record pointers in 64-bit longs: for off-heap
+ * mode, just store the raw address, and for on-heap mode use the upper 13 bits of the address to
+ * store a "page number" and the lower 51 bits to store an offset within this page. These page
+ * numbers are used to index into a "page table" array inside of the MemoryManager in order to
+ * retrieve the base object.
+ *
+ * This allows us to address 8192 pages. In on-heap mode, the maximum page size is limited by the
+ * maximum size of a long[] array, allowing us to address 8192 * (2^31 - 1) * 8 bytes, which is
+ * approximately 140 terabytes of memory.
+ */
+public class TaskMemoryManager {
+
+ private static final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class);
+
+ /** The number of bits used to address the page table. */
+ private static final int PAGE_NUMBER_BITS = 13;
+
+ /** The number of bits used to encode offsets in data pages. */
+ @VisibleForTesting
+ static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS; // 51
+
+ /** The number of entries in the page table. */
+ private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
+
+ /**
+ * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is
+ * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's
+ * maximum page size is limited by the maximum amount of data that can be stored in a long[]
+ * array, which is (2^31 - 1) * 8 bytes (or about 17 gigabytes). Therefore, we cap this at 17
+ * gigabytes.
+ */
+ public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L;
+
+ /** Bit mask for the lower 51 bits of a long. */
+ private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
+
+ /**
+ * Similar to an operating system's page table, this array maps page numbers into base object
+ * pointers, allowing us to translate between the hashtable's internal 64-bit address
+ * representation and the baseObject+offset representation which we use to support both on- and
+ * off-heap addresses. When using an off-heap allocator, every entry in this map will be `null`.
+ * When using an on-heap allocator, the entries in this map will point to pages' base objects.
+ * Entries are added to this map as new data pages are allocated.
+ */
+ private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE];
+
+ /**
+ * Bitmap for tracking which page numbers are in use; a clear bit means the page number is free.
+ */
+ private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE);
+
+ private final MemoryManager memoryManager;
+
+ private final long taskAttemptId;
+
+ /**
+ * Tracks whether we're on-heap or off-heap. For off-heap, we short-circuit most of these methods
+ * without doing any masking or lookups. Since this branching should be well-predicted by the JIT,
+ * this extra layer of indirection / abstraction hopefully shouldn't be too expensive.
+ */
+ final MemoryMode tungstenMemoryMode;
+
+ /**
+ * Tracks spillable memory consumers.
+ */
+ @GuardedBy("this")
+ private final HashSet<MemoryConsumer> consumers;
+
+ /**
+ * The amount of memory that is acquired but not used.
+ */
+ private volatile long acquiredButNotUsed = 0L;
+
+ /**
+ * Construct a new TaskMemoryManager.
+ */
+ public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) {
+ this.tungstenMemoryMode = memoryManager.tungstenMemoryMode();
+ this.memoryManager = memoryManager;
+ this.taskAttemptId = taskAttemptId;
+ this.consumers = new HashSet<>();
+ }
+
+ /**
+ * Acquire N bytes of memory for a consumer. If there is not enough memory, it will call
+ * spill() on other consumers to release more memory.
+ *
+ * @return number of bytes successfully granted (<= N).
+ */
+ public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
+ assert(required >= 0);
+ assert(consumer != null);
+ MemoryMode mode = consumer.getMode();
+ // If we are allocating Tungsten pages off-heap and receive a request to allocate on-heap
+ // memory here, then it may not make sense to spill since that would only end up freeing
+ // off-heap memory. This is subject to change, though, so it may be risky to make this
+ // optimization now in case we forget to undo it later when making changes.
+ synchronized (this) {
+ long got = memoryManager.acquireExecutionMemory(required, taskAttemptId, mode);
+
+ // Try to release memory from other consumers first, so we can reduce the frequency of
+ // spilling and avoid having too many spill files.
+ if (got < required) {
+ // Call spill() on other consumers to release memory
+ // Sort the consumers according to their memory usage, so we avoid repeatedly spilling the
+ // same consumer that was just spilled, which would produce many small spill files.
+ TreeMap<Long, List<MemoryConsumer>> sortedConsumers = new TreeMap<>();
+ for (MemoryConsumer c: consumers) {
+ if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) {
+ long key = c.getUsed();
+ List<MemoryConsumer> list =
+ sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1));
+ list.add(c);
+ }
+ }
+ while (!sortedConsumers.isEmpty()) {
+ // Get the consumer using the least memory that is still more than the remaining required memory.
+ Map.Entry<Long, List<MemoryConsumer>> currentEntry =
+ sortedConsumers.ceilingEntry(required - got);
+ // No consumer uses more memory than the remaining required memory;
+ // fall back to the consumer with the largest memory usage.
+ if (currentEntry == null) {
+ currentEntry = sortedConsumers.lastEntry();
+ }
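+ // Illustrative example: if the candidates currently use 10MB, 30MB and 80MB and 25MB is
+ // still required, ceilingEntry picks the 30MB consumer so a single spill can satisfy the
+ // request; if 100MB were required, ceilingEntry returns null and lastEntry falls back to
+ // the 80MB consumer.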
+ List<MemoryConsumer> cList = currentEntry.getValue();
+ MemoryConsumer c = cList.get(cList.size() - 1);
+ try {
+ long released = c.spill(required - got, consumer);
+ if (released > 0) {
+ logger.debug("Task {} released {} from {} for {}", taskAttemptId,
+ Utils.bytesToString(released), c, consumer);
+ got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode);
+ if (got >= required) {
+ break;
+ }
+ } else {
+ cList.remove(cList.size() - 1);
+ if (cList.isEmpty()) {
+ sortedConsumers.remove(currentEntry.getKey());
+ }
+ }
+ } catch (ClosedByInterruptException e) {
+ // This is thrown when the user kills a task (e.g. a speculative task).
+ logger.error("error while calling spill() on " + c, e);
+ throw new RuntimeException(e.getMessage());
+ } catch (IOException e) {
+ logger.error("error while calling spill() on " + c, e);
+ // checkstyle.off: RegexpSinglelineJava
+ throw new SparkOutOfMemoryError("error while calling spill() on " + c + " : "
+ + e.getMessage());
+ // checkstyle.on: RegexpSinglelineJava
+ }
+ }
+ }
+
+ // call spill() on itself
+ if (got < required) {
+ try {
+ long released = consumer.spill(required - got, consumer);
+ if (released > 0) {
+ logger.debug("Task {} released {} from itself ({})", taskAttemptId,
+ Utils.bytesToString(released), consumer);
+ got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode);
+ }
+ } catch (ClosedByInterruptException e) {
+ // This is thrown when the user kills a task (e.g. a speculative task).
+ logger.error("error while calling spill() on " + consumer, e);
+ throw new RuntimeException(e.getMessage());
+ } catch (IOException e) {
+ logger.error("error while calling spill() on " + consumer, e);
+ // checkstyle.off: RegexpSinglelineJava
+ throw new SparkOutOfMemoryError("error while calling spill() on " + consumer + " : "
+ + e.getMessage());
+ // checkstyle.on: RegexpSinglelineJava
+ }
+ }
+
+ consumers.add(consumer);
+ logger.debug("Task {} acquired {} for {}", taskAttemptId, Utils.bytesToString(got), consumer);
+ return got;
+ }
+ }
+
+ /**
+ * Acquire extended (PMem) memory for a consumer. Acquiring extended memory never triggers a
+ * spill.
+ * @param required the number of bytes to acquire
+ * @return the number of bytes actually granted
+ */
+ public long acquireExtendedMemory(long required, MemoryConsumer consumer) {
+ assert(required >= 0);
+ logger.debug("Task {} acquire {} bytes PMem memory.", taskAttemptId, Utils.bytesToString(required));
+ synchronized (this) {
+ long got = memoryManager.acquireExtendedMemory(required, taskAttemptId);
+ logger.debug("Task {} got {} bytes PMem memory.", taskAttemptId, Utils.bytesToString(got));
+ // The MemoryConsumer that acquired extended memory should be tracked by the TaskMemoryManager.
+ // It is not yet clear whether it belongs in `consumers`; a separate list for consumers that
+ // use extended memory may be preferable.
+ consumers.add(consumer);
+ return got;
+ }
+ }
+
+ /**
+ * Release N bytes of execution memory for a MemoryConsumer.
+ */
+ public void releaseExecutionMemory(long size, MemoryConsumer consumer) {
+ logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer);
+ memoryManager.releaseExecutionMemory(size, taskAttemptId, consumer.getMode());
+ }
+
+ public long acquireExtendedMemory(long required) {
+ assert(required >= 0);
+ logger.debug("Task {} acquire {} bytes PMem memory.", taskAttemptId, Utils.bytesToString(required));
+ synchronized (this) {
+ long got = memoryManager.acquireExtendedMemory(required, taskAttemptId);
+ return got;
+ }
+ }
+
+ public void releaseExtendedMemory(long size) {
+ logger.debug("Task {} release {} PMem space.", taskAttemptId, Utils.bytesToString(size));
+ memoryManager.releaseExtendedMemory(size, taskAttemptId);
+ }
+
+ /**
+ * Release extended (PMem) memory acquired by a MemoryConsumer.
+ * @param size the number of bytes to release
+ */
+ public void releaseExtendedMemory(long size, MemoryConsumer consumer) {
+ logger.debug("Task {} release {} PMem space.", taskAttemptId, Utils.bytesToString(size));
+ memoryManager.releaseExtendedMemory(size, taskAttemptId);
+ }
+
+ /**
+ * Dump the memory usage of all consumers.
+ */
+ public void showMemoryUsage() {
+ logger.info("Memory used in task " + taskAttemptId);
+ synchronized (this) {
+ long memoryAccountedForByConsumers = 0;
+ for (MemoryConsumer c: consumers) {
+ long totalMemUsage = c.getUsed();
+ memoryAccountedForByConsumers += totalMemUsage;
+ if (totalMemUsage > 0) {
+ logger.info("Acquired by " + c + ": " + Utils.bytesToString(totalMemUsage));
+ }
+ }
+ long memoryNotAccountedFor =
+ memoryManager.getExecutionMemoryUsageForTask(taskAttemptId) - memoryAccountedForByConsumers;
+ logger.info(
+ "{} bytes of memory were used by task {} but are not associated with specific consumers",
+ memoryNotAccountedFor, taskAttemptId);
+ logger.info(
+ "{} bytes of memory are used for execution and {} bytes of memory are used for storage",
+ memoryManager.executionMemoryUsed(), memoryManager.storageMemoryUsed());
+ }
+ }
+
+ /**
+ * Return the page size in bytes.
+ */
+ public long pageSizeBytes() {
+ return memoryManager.pageSizeBytes();
+ }
+
+ /**
+ * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is
+ * intended for allocating large blocks of Tungsten memory that will be shared between operators.
+ *
+ * Returns `null` if there was not enough memory to allocate the page. May return a page that
+ * contains fewer bytes than requested, so callers should verify the size of returned pages.
+ *
+ * @throws TooLargePageException
+ */
+ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
+ return allocatePage(size, consumer, false);
+ }
+
+ public MemoryBlock allocatePage(long size, MemoryConsumer consumer, boolean useExtendedMem) {
+ assert(consumer != null);
+ assert(consumer.getMode() == tungstenMemoryMode);
+ if (size > MAXIMUM_PAGE_SIZE_BYTES) {
+ throw new TooLargePageException(size);
+ }
+
+ long acquired = 0L;
+ if (useExtendedMem) {
+ acquired = acquireExtendedMemory(size, consumer);
+ } else {
+ acquired = acquireExecutionMemory(size, consumer);
+ }
+ if (acquired <= 0) {
+ return null;
+ }
+
+ final int pageNumber;
+ synchronized (this) {
+ pageNumber = allocatedPages.nextClearBit(0);
+ if (pageNumber >= PAGE_TABLE_SIZE) {
+ if (useExtendedMem) {
+ releaseExtendedMemory(acquired, consumer);
+ } else {
+ releaseExecutionMemory(acquired, consumer);
+ }
+ throw new IllegalStateException(
+ "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
+ }
+ allocatedPages.set(pageNumber);
+ }
+ MemoryBlock page = null;
+ try {
+ if (useExtendedMem) {
+ page = memoryManager.extendedMemoryAllocator().allocate(size);
+ page.isExtendedMemory(true);
+ } else {
+ page = memoryManager.tungstenMemoryAllocator().allocate(acquired);
+ }
+ } catch (OutOfMemoryError e) {
+ logger.warn("Failed to allocate a page ({} bytes), try again.", acquired);
+ // There is not actually enough memory: the free memory is smaller than the MemoryManager
+ // thought, so keep the acquired memory reserved.
+ synchronized (this) {
+ acquiredButNotUsed += acquired;
+ allocatedPages.clear(pageNumber);
+ }
+ if (useExtendedMem) {
+ // do not force a spill when using extended memory
+ return null;
+ } else {
+ // this could trigger spilling to free some pages.
+ return allocatePage(size, consumer);
+ }
+ }
+
+ page.pageNumber = pageNumber;
+ pageTable[pageNumber] = page;
+ if (logger.isTraceEnabled()) {
+ logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired);
+ }
+ return page;
+ }
+
+ /**
+ * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}.
+ */
+ public void freePage(MemoryBlock page, MemoryConsumer consumer) {
+ assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) :
+ "Called freePage() on memory that wasn't allocated with allocatePage()";
+ assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
+ "Called freePage() on a memory block that has already been freed";
+ assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) :
+ "Called freePage() on a memory block that has already been freed";
+ assert(allocatedPages.get(page.pageNumber));
+ pageTable[page.pageNumber] = null;
+ synchronized (this) {
+ allocatedPages.clear(page.pageNumber);
+ }
+ if (logger.isTraceEnabled()) {
+ logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
+ }
+ long pageSize = page.size();
+ // Clear the page number before passing the block to the MemoryAllocator's free().
+ // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed
+ // page has been inappropriately directly freed without calling TMM.freePage().
+ page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
+ if (page.isExtendedMemory) {
+ memoryManager.extendedMemoryAllocator().free(page);
+ releaseExtendedMemory(pageSize, consumer);
+ } else {
+ memoryManager.tungstenMemoryAllocator().free(page);
+ releaseExecutionMemory(pageSize, consumer);
+ }
+ }
+
+ /**
+ * Allocate a PMem page of `size` bytes, reusing a previously allocated PMem page from the
+ * MemoryManager when a usable one is available.
+ * @param size the requested page size in bytes
+ * @return the allocated PMem page, or `null` if the allocation failed
+ */
+ public MemoryBlock allocatePMemPage(long size) {
+ if (size > MAXIMUM_PAGE_SIZE_BYTES) {
+ throw new TooLargePageException(size);
+ }
+ final int pageNumber;
+ synchronized (this) {
+ pageNumber = allocatedPages.nextClearBit(0);
+ if (pageNumber >= PAGE_TABLE_SIZE) {
+ throw new IllegalStateException(
+ "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
+ }
+ allocatedPages.set(pageNumber);
+ }
+ MemoryBlock page = null;
+ try {
+ page = memoryManager.getUsablePMemPage(size);
+ if (page == null) {
+ page = memoryManager.extendedMemoryAllocator().allocate(size);
+ memoryManager.addPMemPages(page);
+ } else {
+ logger.debug("reuse pmem page.");
+ }
+ } catch (OutOfMemoryError e) {
+ logger.debug("Failed to allocate a PMem page ({} bytes).", size);
+ return null;
+ }
+ page.isExtendedMemory(true);
+ page.pageNumber = pageNumber;
+ pageTable[pageNumber] = page;
+ if (logger.isTraceEnabled()) {
+ logger.trace("Allocate page number {} ({} bytes)", pageNumber, size);
+ }
+ return page;
+
+ }
+
+ public void freePMemPage(MemoryBlock page, MemoryConsumer consumer) {
+ assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) :
+ "Called freePage() on memory that wasn't allocated with allocatePage()";
+ assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
+ "Called freePage() on a memory block that has already been freed";
+ assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) :
+ "Called freePage() on a memory block that has already been freed";
+ assert(allocatedPages.get(page.pageNumber));
+ pageTable[page.pageNumber] = null;
+ synchronized (this) {
+ allocatedPages.clear(page.pageNumber);
+ }
+ if (logger.isTraceEnabled()) {
+ logger.trace("Freed PMem page number {} ({} bytes)", page.pageNumber, page.size());
+ }
+ // Clear the page number before passing the block to the MemoryAllocator's free().
+ // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed
+ // page has been inappropriately directly freed without calling TMM.freePage().
+
+ page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
+ // Do not actually free the PMem page, so that it can be reused later.
+ }
+
+ /**
+ * Given a memory page and offset within that page, encode this address into a 64-bit long.
+ * This address will remain valid as long as the corresponding page has not been freed.
+ *
+ * @param page a data page allocated by {@link TaskMemoryManager#allocatePage}/
+ * @param offsetInPage an offset in this page which incorporates the base offset. In other words,
+ * this should be the value that you would pass as the base offset into an
+ * UNSAFE call (e.g. page.baseOffset() + something).
+ * @return an encoded page address.
+ */
+ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
+ if (tungstenMemoryMode == MemoryMode.OFF_HEAP) {
+ // In off-heap mode, an offset is an absolute address that may require a full 64 bits to
+ // encode. Due to our page size limitation, though, we can convert this into an offset that's
+ // relative to the page's base offset; this relative offset will fit in 51 bits.
+ offsetInPage -= page.getBaseOffset();
+ }
+ return encodePageNumberAndOffset(page.pageNumber, offsetInPage);
+ }
+
+ @VisibleForTesting
+ public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
+ assert (pageNumber >= 0) : "encodePageNumberAndOffset called with invalid page";
+ return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
+ }
+
+ @VisibleForTesting
+ public static int decodePageNumber(long pagePlusOffsetAddress) {
+ return (int) (pagePlusOffsetAddress >>> OFFSET_BITS);
+ }
+
+ private static long decodeOffset(long pagePlusOffsetAddress) {
+ return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
+ }
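+ // Worked example (derived from the constants above): with OFFSET_BITS = 51, encoding
+ // pageNumber = 5 and offsetInPage = 64 yields (5L << 51) | 64 = 0x28000000000040L;
+ // decodePageNumber() and decodeOffset() recover 5 and 64 from that value.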
+
+ /**
+ * Get the page associated with an address encoded by
+ * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
+ */
+ public Object getPage(long pagePlusOffsetAddress) {
+ if (tungstenMemoryMode == MemoryMode.ON_HEAP) {
+ final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
+ assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
+ final MemoryBlock page = pageTable[pageNumber];
+ assert (page != null);
+ assert (page.getBaseObject() != null);
+ return page.getBaseObject();
+ } else {
+ return null;
+ }
+ }
+
+ public MemoryBlock getOriginalPage(long pagePlusOffsetAddress) {
+ if (tungstenMemoryMode == MemoryMode.ON_HEAP) {
+ final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
+ assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
+ final MemoryBlock page = pageTable[pageNumber];
+ assert (page != null);
+ assert (page.getBaseObject() != null);
+ return page;
+ } else {
+ return null;
+ }
+ }
+
+ /**
+ * Get the offset associated with an address encoded by
+ * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
+ */
+ public long getOffsetInPage(long pagePlusOffsetAddress) {
+ final long offsetInPage = decodeOffset(pagePlusOffsetAddress);
+ if (tungstenMemoryMode == MemoryMode.ON_HEAP) {
+ return offsetInPage;
+ } else {
+ // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we
+ // converted the absolute address into a relative address. Here, we invert that operation:
+ final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
+ assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
+ final MemoryBlock page = pageTable[pageNumber];
+ assert (page != null);
+ return page.getBaseOffset() + offsetInPage;
+ }
+ }
+
+ /**
+ * Clean up all allocated memory and pages. Returns the number of bytes freed. A non-zero return
+ * value can be used to detect memory leaks.
+ */
+ public long cleanUpAllAllocatedMemory() {
+ synchronized (this) {
+ for (MemoryConsumer c: consumers) {
+ if (c != null && c.getUsed() > 0) {
+ // In case of failed task, it's normal to see leaked memory
+ logger.debug("unreleased " + Utils.bytesToString(c.getUsed()) + " memory from " + c);
+ }
+ }
+ consumers.clear();
+
+ for (MemoryBlock page : pageTable) {
+ if (page != null) {
+ logger.debug("unreleased page: " + page + " in task " + taskAttemptId);
+ page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
+ if (!page.isExtendedMemory){
+ memoryManager.tungstenMemoryAllocator().free(page);
+ } else {
+ memoryManager.extendedMemoryAllocator().free(page);
+ }
+ }
+ }
+ Arrays.fill(pageTable, null);
+ }
+
+ // release the memory that is not used by any consumer (acquired for pages in tungsten mode).
+ memoryManager.releaseExecutionMemory(acquiredButNotUsed, taskAttemptId, tungstenMemoryMode);
+ memoryManager.releaseAllExtendedMemoryForTask(taskAttemptId);
+
+ return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);
+ }
+
+ /**
+ * Returns the memory consumption, in bytes, for the current task.
+ */
+ public long getMemoryConsumptionForThisTask() {
+ return memoryManager.getExecutionMemoryUsageForTask(taskAttemptId);
+ }
+
+ /**
+ * Returns Tungsten memory mode
+ */
+ public MemoryMode getTungstenMemoryMode() {
+ return tungstenMemoryMode;
+ }
+}
diff --git a/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index d4cc153e..e7463908 100644
--- a/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -25,10 +25,6 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Closeables;
-import com.intel.oap.common.storage.stream.ChunkInputStream;
-import com.intel.oap.common.storage.stream.DataStore;
-import org.apache.spark.internal.config.package$;
-import org.apache.spark.memory.PMemManagerInitializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -286,7 +282,7 @@ private void advanceToNextPage() {
}
try {
Closeables.close(reader, /* swallowIOException = */ false);
- reader = spillWriters.getFirst().getReader(serializerManager);
+ reader = spillWriters.getFirst().getReader(serializerManager, null);
recordsInPage = -1;
} catch (IOException e) {
// Scala iterator does not handle exception
@@ -397,23 +393,10 @@ public void remove() {
}
private void handleFailedDelete() {
- // remove the spill file from disk or pmem
- boolean pMemSpillEnabled = SparkEnv.get() != null && (boolean) SparkEnv.get().conf().get(
- package$.MODULE$.MEMORY_SPILL_PMEM_ENABLED());
+ // remove the spill file from disk
File file = spillWriters.removeFirst().getFile();
- if(pMemSpillEnabled == true) {
- try {
- ChunkInputStream cis = ChunkInputStream.getChunkInputStreamInstance(file.toString(),
- new DataStore(PMemManagerInitializer.getPMemManager(),
- PMemManagerInitializer.getProperties()));
- cis.free();
- } catch (IOException e) {
- logger.debug(e.toString());
- }
- } else {
- if (file != null && file.exists() && !file.delete()) {
- logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
- }
+ if (file != null && file.exists() && !file.delete()) {
+ logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
}
}
}
@@ -835,22 +818,6 @@ public void free() {
}
assert(dataPages.isEmpty());
- deleteSpillFiles();
- }
-
- /**
- * Deletes any spill files created by this consumer.
- */
- private void deleteSpillFiles() {
- boolean pMemSpillEnabled = SparkEnv.get() != null && (boolean) SparkEnv.get().conf().get(
- package$.MODULE$.MEMORY_SPILL_PMEM_ENABLED());
- if (pMemSpillEnabled == true)
- deletePMemSpillFiles();
- else
- deleteDiskSpillFiles();
- }
-
- private void deleteDiskSpillFiles() {
while (!spillWriters.isEmpty()) {
File file = spillWriters.removeFirst().getFile();
if (file != null && file.exists()) {
@@ -861,20 +828,6 @@ private void deleteDiskSpillFiles() {
}
}
- private void deletePMemSpillFiles() {
- while (!spillWriters.isEmpty()) {
- File file = spillWriters.removeFirst().getFile();
- try {
- ChunkInputStream cis = ChunkInputStream.getChunkInputStreamInstance(file.toString(),
- new DataStore(PMemManagerInitializer.getPMemManager(),
- PMemManagerInitializer.getProperties()));
- cis.free();
- } catch (IOException e) {
- logger.debug(e.toString());
- }
- }
- }
-
public TaskMemoryManager getTaskMemoryManager() {
return taskMemoryManager;
}
diff --git a/src/main/java/org/apache/spark/unsafe/memory/ExtendedMemoryAllocator.java b/src/main/java/org/apache/spark/unsafe/memory/ExtendedMemoryAllocator.java
new file mode 100644
index 00000000..7223068a
--- /dev/null
+++ b/src/main/java/org/apache/spark/unsafe/memory/ExtendedMemoryAllocator.java
@@ -0,0 +1,21 @@
+
+package org.apache.spark.unsafe.memory;
+import com.intel.oap.common.unsafe.PersistentMemoryPlatform;
+
+public class ExtendedMemoryAllocator implements MemoryAllocator {
+
+ @Override
+ public MemoryBlock allocate(long size) throws OutOfMemoryError {
+ long address = PersistentMemoryPlatform.allocateVolatileMemory(size);
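+ // The base object is null: like off-heap memory, a PMem block is addressed purely by its
+ // raw address (stored as the block's offset).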
+ MemoryBlock memoryBlock = new MemoryBlock(null, address, size);
+
+ return memoryBlock;
+ }
+
+ @Override
+ public void free(MemoryBlock memoryBlock) {
+ assert (memoryBlock.getBaseObject() == null) :
+ "baseObject not null; are you trying to use the AEP-heap allocator to free on-heap memory?";
+ PersistentMemoryPlatform.freeMemory(memoryBlock.getBaseOffset());
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java
new file mode 100644
index 00000000..b47370f4
--- /dev/null
+++ b/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+public interface MemoryAllocator {
+
+ /**
+ * Whether to fill newly allocated and deallocated memory with 0xa5 and 0x5a bytes respectively.
+ * This helps catch misuse of uninitialized or freed memory, but imposes some overhead.
+ */
+ boolean MEMORY_DEBUG_FILL_ENABLED = Boolean.parseBoolean(
+ System.getProperty("spark.memory.debugFill", "false"));
+
+ // Same as jemalloc's debug fill values.
+ byte MEMORY_DEBUG_FILL_CLEAN_VALUE = (byte)0xa5;
+ byte MEMORY_DEBUG_FILL_FREED_VALUE = (byte)0x5a;
+
+ /**
+ * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed
+ * to be zeroed out (call `fill(0)` on the result if this is necessary).
+ */
+ MemoryBlock allocate(long size) throws OutOfMemoryError;
+
+ void free(MemoryBlock memory);
+
+ MemoryAllocator UNSAFE = new UnsafeMemoryAllocator();
+
+ MemoryAllocator HEAP = new HeapMemoryAllocator();
+
+ MemoryAllocator EXTENDED = new ExtendedMemoryAllocator();
+}
diff --git a/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
new file mode 100644
index 00000000..7517b3a0
--- /dev/null
+++ b/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+import javax.annotation.Nullable;
+
+import org.apache.spark.unsafe.Platform;
+
+/**
+ * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size.
+ */
+public class MemoryBlock extends MemoryLocation {
+
+ /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */
+ public static final int NO_PAGE_NUMBER = -1;
+
+ /**
+ * Special `pageNumber` value for marking pages that have been freed in the TaskMemoryManager.
+ * We set `pageNumber` to this value in TaskMemoryManager.freePage() so that MemoryAllocator
+ * can detect if pages which were allocated by TaskMemoryManager have been freed in the TMM
+ * before being passed to MemoryAllocator.free() (it is an error to allocate a page in
+ * TaskMemoryManager and then directly free it in a MemoryAllocator without going through
+ * the TMM freePage() call).
+ */
+ public static final int FREED_IN_TMM_PAGE_NUMBER = -2;
+
+ /**
+ * Special `pageNumber` value for pages that have been freed by the MemoryAllocator. This allows
+ * us to detect double-frees.
+ */
+ public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3;
+
+ private final long length;
+
+ /**
+ * Optional page number; used when this MemoryBlock represents a page allocated by a
+ * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager,
+ * which lives in a different package.
+ */
+ public int pageNumber = NO_PAGE_NUMBER;
+
+ /**
+ * Indicate the memory block is on extended memory or not.
+ */
+ public boolean isExtendedMemory = false;
+
+ public MemoryBlock(@Nullable Object obj, long offset, long length) {
+ super(obj, offset);
+ this.length = length;
+ }
+
+ /**
+ * Returns the size of the memory block.
+ */
+ public long size() {
+ return length;
+ }
+
+ /**
+ * Creates a memory block pointing to the memory used by the long array.
+ */
+ public static MemoryBlock fromLongArray(final long[] array) {
+ return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L);
+ }
+
+ /**
+ * Fills the memory block with the specified byte value.
+ */
+ public void fill(byte value) {
+ Platform.setMemory(obj, offset, length, value);
+ }
+
+ /**
+ * Whether this memory block is on extended (PMem) memory.
+ * @return true if this block resides in extended memory
+ */
+ public boolean isExtendedMemory() {
+ return isExtendedMemory;
+ }
+
+ /**
+ * Set whether this memory block is on extended (PMem) memory.
+ * @param isExtended true if the block resides in extended memory
+ */
+ public void isExtendedMemory(boolean isExtended) {
+ this.isExtendedMemory = isExtended;
+ }
+}
diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemReader.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemReader.java
new file mode 100644
index 00000000..1d2b1ee3
--- /dev/null
+++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemReader.java
@@ -0,0 +1,87 @@
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+import java.io.Closeable;
+import java.util.LinkedList;
+
+public final class PMemReader extends UnsafeSorterIterator implements Closeable {
+ private int recordLength;
+ private long keyPrefix;
+ private int numRecordsRemaining;
+ private int numRecords;
+ private LinkedList<MemoryBlock> pMemPages;
+ private MemoryBlock pMemPage = null;
+ private int readingPageIndex = 0;
+ private int readedRecordsInCurrentPage = 0;
+ private int numRecordsInpage = 0;
+ private long offset = 0;
+ private byte[] arr = new byte[1024 * 1024];
+ private Object baseObject = arr;
+ public PMemReader(LinkedList<MemoryBlock> pMemPages, int numRecords) {
+ this.pMemPages = pMemPages;
+ this.numRecordsRemaining = this.numRecords = numRecords;
+ }
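+
+ // Assumed PMem page layout (inferred from loadNext() below): each page starts with an int
+ // record count, followed by records stored as [keyPrefix: long][recordLength: int][payload].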
+ @Override
+ public void loadNext() {
+ assert (readingPageIndex <= pMemPages.size())
+ : "Illegal state: Pages finished read but hasNext() is true.";
+ if(pMemPage == null || readedRecordsInCurrentPage == numRecordsInpage) {
+ // read records from each page
+ pMemPage = pMemPages.get(readingPageIndex++);
+ readedRecordsInCurrentPage = 0;
+ numRecordsInpage = Platform.getInt(null, pMemPage.getBaseOffset());
+ offset = pMemPage.getBaseOffset() + 4;
+ }
+ // Each record is laid out as [keyPrefix: long][recordLength: int][record bytes].
+ keyPrefix = Platform.getLong(null, offset);
+ offset += 8;
+ recordLength = Platform.getInt(null, offset);
+ offset += 4;
+ if (recordLength > arr.length) {
+ arr = new byte[recordLength];
+ baseObject = arr;
+ }
+ Platform.copyMemory(null, offset, baseObject, Platform.BYTE_ARRAY_OFFSET, recordLength);
+ offset += recordLength;
+ readedRecordsInCurrentPage++;
+ numRecordsRemaining--;
+ }
+ @Override
+ public int getNumRecords() {
+ return numRecords;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return (numRecordsRemaining > 0);
+ }
+
+ @Override
+ public Object getBaseObject() {
+ return baseObject;
+ }
+
+ @Override
+ public long getBaseOffset() {
+ return Platform.BYTE_ARRAY_OFFSET;
+ }
+
+ @Override
+ public int getRecordLength() {
+ return recordLength;
+ }
+
+ @Override
+ public long getKeyPrefix() {
+ return keyPrefix;
+ }
+
+ @Override
+ public void close() {
+ // do nothing here
+ }
+}
diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemReaderForUnsafeExternalSorter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemReaderForUnsafeExternalSorter.java
new file mode 100644
index 00000000..f970b0dd
--- /dev/null
+++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemReaderForUnsafeExternalSorter.java
@@ -0,0 +1,141 @@
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.SparkEnv;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.internal.config.package$;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.Closeable;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+public final class PMemReaderForUnsafeExternalSorter extends UnsafeSorterIterator implements Closeable {
+ private static final Logger logger = LoggerFactory.getLogger(PMemReaderForUnsafeExternalSorter.class);
+ private int recordLength;
+ private long keyPrefix;
+ private int numRecordsRemaining;
+ private int numRecords;
+ private LongArray sortedArray;
+ private int position;
+ private byte[] arr = new byte[1024 * 1024];
+ private byte[] bytes = new byte[1024 * 1024];
+ private Object baseObject = arr;
+ private TaskMetrics taskMetrics;
+ private long startTime;
+ private ByteBuffer byteBuffer;
+ public PMemReaderForUnsafeExternalSorter(
+ LongArray sortedArray, int position, int numRecords, TaskMetrics taskMetrics) {
+ this.sortedArray = sortedArray;
+ this.position = position;
+ this.numRecords = numRecords;
+ this.numRecordsRemaining = numRecords - position/2;
+ this.taskMetrics = taskMetrics;
+ int readBufferSize = SparkEnv.get() == null? 8 * 1024 * 1024 :
+ (int) SparkEnv.get().conf().get(package$.MODULE$.MEMORY_SPILL_PMEM_READ_BUFFERSIZE());
+ logger.info("PMem read buffer size is:" + Utils.bytesToString(readBufferSize));
+ this.byteBuffer = ByteBuffer.wrap(new byte[readBufferSize]);
+ byteBuffer.flip();
+ byteBuffer.order(ByteOrder.nativeOrder());
+ }
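+
+ // Read path (a summary of the code below): each even slot of `sortedArray` holds the PMem
+ // address of a record stored as [length: int][payload], and the following odd slot holds its
+ // key prefix. loadData() repacks these into `byteBuffer` as [keyPrefix: long][length: int]
+ // [payload] entries, and loadNext() decodes one entry per call.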
+
+ @Override
+ public void loadNext() {
+ if (!byteBuffer.hasRemaining()) {
+ boolean refilled = refill();
+ if (!refilled) {
+ logger.error("Illegal status: records finished read but hasNext() is true.");
+ }
+ }
+ keyPrefix = byteBuffer.getLong();
+ recordLength = byteBuffer.getInt();
+ if (recordLength > arr.length) {
+ arr = new byte[recordLength];
+ baseObject = arr;
+ }
+ byteBuffer.get(arr, 0, recordLength);
+ numRecordsRemaining --;
+ }
+
+ @Override
+ public int getNumRecords() {
+ return numRecords;
+ }
+
+ /**
+ * Load more PMem records into the read buffer.
+ * @return true if at least one record was loaded
+ */
+ private boolean refill() {
+ byteBuffer.clear();
+ int nRead = loadData();
+ byteBuffer.flip();
+ return nRead > 0;
+ }
+
+ private int loadData() {
+ // no records remaining to read
+ if (position >= numRecords * 2)
+ return -1;
+ int bufferPos = 0;
+ int capacity = byteBuffer.capacity();
+ while (bufferPos < capacity && position < numRecords * 2) {
+ long curRecordAddress = sortedArray.get(position);
+ int recordLen = Platform.getInt(null, curRecordAddress);
+ // total entry size: length field (int) + key prefix (long) + record payload
+ int length = Integer.BYTES + Long.BYTES + recordLen;
+ if (length > capacity) {
+ logger.error("single record size exceeds PMem read buffer. Please increase buffer size.");
+ }
+ if (bufferPos + length <= capacity) {
+ long curKeyPrefix = sortedArray.get(position + 1);
+ if (length > bytes.length) {
+ bytes = new byte[length];
+ }
+ Platform.putLong(bytes, Platform.BYTE_ARRAY_OFFSET, curKeyPrefix);
+ Platform.copyMemory(null, curRecordAddress, bytes, Platform.BYTE_ARRAY_OFFSET + Long.BYTES, length - Long.BYTES);
+ byteBuffer.put(bytes, 0, length);
+ bufferPos += length;
+ position += 2;
+ } else {
+ break;
+ }
+ }
+ return bufferPos;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return (numRecordsRemaining > 0);
+ }
+
+ @Override
+ public Object getBaseObject() {
+ return baseObject;
+ }
+
+ @Override
+ public long getBaseOffset() {
+ return Platform.BYTE_ARRAY_OFFSET;
+ }
+
+ @Override
+ public int getRecordLength() {
+ return recordLength;
+ }
+
+ @Override
+ public long getKeyPrefix() {
+ return keyPrefix;
+ }
+
+ @Override
+ public void close() {
+ // do nothing here
+ }
+}
diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemSpillWriterFactory.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemSpillWriterFactory.java
new file mode 100644
index 00000000..97d7c61d
--- /dev/null
+++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemSpillWriterFactory.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.serializer.SerializerManager;
+import org.apache.spark.storage.BlockManager;
+
+import java.io.IOException;
+
+public class PMemSpillWriterFactory {
+ public static SpillWriterForUnsafeSorter getSpillWriter(
+ PMemSpillWriterType writerType,
+ UnsafeExternalSorter externalSorter,
+ UnsafeSorterIterator sortedIterator,
+ int numberOfRecordsToWritten,
+ SerializerManager serializerManager,
+ BlockManager blockManager,
+ int fileBufferSize,
+ ShuffleWriteMetrics writeMetrics,
+ TaskMetrics taskMetrics,
+ boolean spillToPMEMEnabled,
+ boolean isSorted) throws IOException {
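+ // Summary of the dispatch below: MEM_COPY_ALL_DATA_PAGES_TO_PMEM_WITHLONGARRAY uses
+ // PMemWriter (copies the data pages plus the pointer array to PMem),
+ // WRITE_SORTED_RECORDS_TO_PMEM uses SortedPMemPageSpillWriter, STREAM_SPILL_TO_PMEM uses
+ // UnsafeSorterStreamSpillWriter, and anything else (or PMem spill disabled) falls back to
+ // the disk-based UnsafeSorterSpillWriter; unmatched combinations return null.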
+ if (spillToPMEMEnabled && writerType == PMemSpillWriterType.MEM_COPY_ALL_DATA_PAGES_TO_PMEM_WITHLONGARRAY){
+ SortedIteratorForSpills sortedSpillIte = SortedIteratorForSpills.createFromExistingSorterIte(
+ (UnsafeInMemorySorter.SortedIterator)sortedIterator,
+ externalSorter.getInMemSorter());
+ return new PMemWriter(
+ externalSorter,
+ sortedSpillIte,
+ isSorted,
+ numberOfRecordsToWritten,
+ serializerManager,
+ blockManager,
+ fileBufferSize,
+ writeMetrics,
+ taskMetrics);
+ } else {
+ if (sortedIterator == null) {
+ sortedIterator = externalSorter.getInMemSorter().getSortedIterator();
+ }
+ if (spillToPMEMEnabled && sortedIterator instanceof UnsafeInMemorySorter.SortedIterator) {
+ if (writerType == PMemSpillWriterType.WRITE_SORTED_RECORDS_TO_PMEM) {
+ SortedIteratorForSpills sortedSpillIte = SortedIteratorForSpills.createFromExistingSorterIte(
+ (UnsafeInMemorySorter.SortedIterator)sortedIterator,
+ externalSorter.getInMemSorter());
+ return new SortedPMemPageSpillWriter(
+ externalSorter,
+ sortedSpillIte,
+ numberOfRecordsToWritten,
+ serializerManager,
+ blockManager,
+ fileBufferSize,
+ writeMetrics,
+ taskMetrics);
+ }
+ if (spillToPMEMEnabled && writerType == PMemSpillWriterType.STREAM_SPILL_TO_PMEM) {
+ return new UnsafeSorterStreamSpillWriter(
+ blockManager,
+ fileBufferSize,
+ sortedIterator,
+ numberOfRecordsToWritten,
+ serializerManager,
+ writeMetrics,
+ taskMetrics);
+ }
+ } else {
+ return new UnsafeSorterSpillWriter(
+ blockManager,
+ fileBufferSize,
+ sortedIterator,
+ numberOfRecordsToWritten,
+ serializerManager,
+ writeMetrics,
+ taskMetrics);
+ }
+
+ }
+ return null;
+ }
+}
diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemSpillWriterType.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemSpillWriterType.java
new file mode 100644
index 00000000..e6bad61d
--- /dev/null
+++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemSpillWriterType.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+public enum PMemSpillWriterType {
+ STREAM_SPILL_TO_PMEM,
+ MEM_COPY_ALL_DATA_PAGES_TO_PMEM,
+ MEM_COPY_ALL_DATA_PAGES_TO_PMEM_WITHLONGARRAY,
+ WRITE_SORTED_RECORDS_TO_PMEM
+}
diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemWriter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemWriter.java
new file mode 100644
index 00000000..2db27343
--- /dev/null
+++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemWriter.java
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import com.intel.oap.common.unsafe.PersistentMemoryPlatform;
+
+import org.apache.spark.SparkEnv;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.internal.config.package$;
+import org.apache.spark.serializer.SerializerManager;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+
+/**
+ * In this writer, records page along with LongArray page are both dumped to PMem when spill happens
+ */
+public final class PMemWriter extends UnsafeSorterPMemSpillWriter {
+ private static final Logger logger = LoggerFactory.getLogger(PMemWriter.class);
+ private LongArray sortedArray;
+ private HashMap<MemoryBlock, MemoryBlock> pageMap = new HashMap<>();
+ private int position;
+ private LinkedList<MemoryBlock> allocatedDramPages;
+ private MemoryBlock pMemPageForLongArray;
+ private UnsafeSorterSpillWriter diskSpillWriter;
+ private BlockManager blockManager;
+ private SerializerManager serializerManager;
+ private int fileBufferSize;
+ private boolean isSorted;
+ private int totalRecordsWritten;
+ private final boolean spillToPMemConcurrently = SparkEnv.get() != null && (boolean) SparkEnv.get().conf().get(
+ package$.MODULE$.MEMORY_SPILL_PMEM_SORT_BACKGROUND());
+ private final boolean pMemClflushEnabled = SparkEnv.get() != null &&
+ (boolean)SparkEnv.get().conf().get(package$.MODULE$.MEMORY_SPILL_PMEM_CLFLUSH_ENABLED());
+
+ public PMemWriter(
+ UnsafeExternalSorter externalSorter,
+ SortedIteratorForSpills sortedIterator,
+ boolean isSorted,
+ int numberOfRecordsToWritten,
+ SerializerManager serializerManager,
+ BlockManager blockManager,
+ int fileBufferSize,
+ ShuffleWriteMetrics writeMetrics,
+ TaskMetrics taskMetrics) {
+ // sortedIterator is either null or the readingIterator from UnsafeExternalSorter.
+ // It isn't used by this PMemWriter; it is only kept so the constructor matches the other
+ // spill writers.
+ super(externalSorter, sortedIterator, numberOfRecordsToWritten, writeMetrics, taskMetrics);
+ this.allocatedDramPages = externalSorter.getAllocatedPages();
+ this.blockManager = blockManager;
+ this.serializerManager = serializerManager;
+ this.fileBufferSize = fileBufferSize;
+ this.isSorted = isSorted;
+ // If the spill happens before the iterator is sorted, the valid records are
+ // [0, inMemSorter.numRecords()]. Once the iterator is sorted, the valid records are
+ // [position/2, inMemSorter.numRecords()].
+ this.totalRecordsWritten = externalSorter.getInMemSorter().numRecords();
+ }
+
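+ // write() first tries to reserve a PMem page for every DRAM data page plus one for the
+ // pointer array (see allocatePMemPages below); if any reservation fails, it falls back to an
+ // ordinary disk spill through UnsafeSorterSpillWriter.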
+ @Override
+ public void write() throws IOException {
+ // write records based on externalsorter
+ // try to allocate all needed PMem pages before spill to PMem
+ UnsafeInMemorySorter inMemSorter = externalSorter.getInMemSorter();
+ if (allocatePMemPages(allocatedDramPages, inMemSorter.getArray().memoryBlock())) {
+ if (spillToPMemConcurrently && !isSorted) {
+ logger.info("Concurrent PMem write/records sort");
+ long writeDuration = 0;
+ ExecutorService executorService = Executors.newSingleThreadExecutor();
+ Future<Long> future = executorService.submit(() -> dumpPagesToPMem());
+ externalSorter.getInMemSorter().getSortedIterator();
+ try {
+ writeDuration = future.get();
+ } catch (InterruptedException | ExecutionException e) {
+ logger.error(e.getMessage());
+ }
+ executorService.shutdownNow();
+ updateLongArray(inMemSorter.getArray(), totalRecordsWritten, 0);
+ } else if(!isSorted) {
+ dumpPagesToPMem();
+ // get sorted iterator
+ externalSorter.getInMemSorter().getSortedIterator();
+ // update LongArray
+ updateLongArray(inMemSorter.getArray(), totalRecordsWritten, 0);
+ } else {
+ dumpPagesToPMem();
+ // get sorted iterator
+ assert(sortedIterator != null);
+ updateLongArray(inMemSorter.getArray(), totalRecordsWritten, sortedIterator.getPosition());
+ }
+ } else {
+ // fallback to disk spill
+ if (diskSpillWriter == null) {
+ diskSpillWriter = new UnsafeSorterSpillWriter(
+ blockManager,
+ fileBufferSize,
+ sortedIterator,
+ numberOfRecordsToWritten,
+ serializerManager,
+ writeMetrics,
+ taskMetrics);
+ }
+ diskSpillWriter.write(false);
+ }
+ }
+
+ public boolean allocatePMemPages(LinkedList<MemoryBlock> dramPages, MemoryBlock longArrayPage) {
+ for (MemoryBlock page: dramPages) {
+ MemoryBlock pMemBlock = taskMemoryManager.allocatePMemPage(page.size());
+ if (pMemBlock != null) {
+ allocatedPMemPages.add(pMemBlock);
+ pageMap.put(page, pMemBlock);
+ } else {
+ pageMap.clear();
+ return false;
+ }
+ }
+ pMemPageForLongArray = taskMemoryManager.allocatePMemPage(longArrayPage.size());
+ if (pMemPageForLongArray != null) {
+ allocatedPMemPages.add(pMemPageForLongArray);
+ pageMap.put(longArrayPage, pMemPageForLongArray);
+ } else {
+ pageMap.clear();
+ return false;
+ }
+ return (allocatedPMemPages.size() == dramPages.size() + 1);
+ }
+
+ private long dumpPagesToPMem() {
+ long dumpTime = System.nanoTime();
+ for (MemoryBlock page : allocatedDramPages) {
+ dumpPageToPMem(page);
+ }
+ long dumpDuration = System.nanoTime() - dumpTime;
+ return dumpDuration;
+ }
+
+ private void dumpPageToPMem(MemoryBlock page) {
+ MemoryBlock pMemBlock = pageMap.get(page);
+ PersistentMemoryPlatform.copyMemory(
+ page.getBaseObject(), page.getBaseOffset(),
+ null, pMemBlock.getBaseOffset(), page.size(),
+ pMemClflushEnabled);
+ writeMetrics.incBytesWritten(page.size());
+ }
+
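+ /**
+ * Rewrites every record pointer in the sorted LongArray from its encoded DRAM page address
+ * to the absolute offset of the record's PMem copy, then copies the LongArray itself to PMem.
+ */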
+ public void updateLongArray(LongArray sortedArray, int numRecords, int position) {
+ this.position = position;
+ while (position < numRecords * 2) {
+ // update recordPointer in this array
+ long originalRecordPointer = sortedArray.get(position);
+ MemoryBlock page = taskMemoryManager.getOriginalPage(originalRecordPointer);
+ long offset = taskMemoryManager.getOffsetInPage(originalRecordPointer) - page.getBaseOffset();
+ MemoryBlock pMemBlock = pageMap.get(page);
+ long pMemOffset = pMemBlock.getBaseOffset() + offset;
+ sortedArray.set(position, pMemOffset);
+ position += 2;
+ }
+ // copy the LongArray to PMem
+ MemoryBlock arrayBlock = sortedArray.memoryBlock();
+ MemoryBlock pMemBlock = pageMap.get(arrayBlock);
+ PersistentMemoryPlatform.copyMemory(
+ arrayBlock.getBaseObject(), arrayBlock.getBaseOffset(),
+ null, pMemBlock.getBaseOffset(), arrayBlock.size(),
+ pMemClflushEnabled);
+ writeMetrics.incBytesWritten(pMemBlock.size());
+ this.sortedArray = new LongArray(pMemBlock);
+ }
+
+ @Override
+ public UnsafeSorterIterator getSpillReader() throws IOException {
+ // TODO: consider partial spill to PMem + Disk.
+ if (diskSpillWriter != null) {
+ return diskSpillWriter.getSpillReader();
+ } else {
+ return new PMemReaderForUnsafeExternalSorter(sortedArray, position, totalRecordsWritten, taskMetrics);
+ }
+ }
+
+ public void clearAll() {
+ freeAllPMemPages();
+ if (diskSpillWriter != null) {
+ diskSpillWriter.clearAll();
+ }
+ }
+
+ @Override
+ public int recordsSpilled() {
+ return numberOfRecordsToWritten;
+ }
+
+ @Override
+ public void freeAllPMemPages() {
+ for (MemoryBlock page : allocatedPMemPages) {
+ taskMemoryManager.freePMemPage(page, externalSorter);
+ }
+ allocatedPMemPages.clear();
+ }
+}
diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/SortedIteratorForSpills.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/SortedIteratorForSpills.java
new file mode 100644
index 00000000..5c2961fd
--- /dev/null
+++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/SortedIteratorForSpills.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
+import org.apache.spark.unsafe.array.LongArray;
+
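+/**
+ * Iterator over the (record pointer, key prefix) pairs held in an in-memory sorter's LongArray,
+ * used by the PMem spill writers to walk records in sorted order while they are copied out of
+ * DRAM. The LongArray handed to this iterator is expected to be sorted already.
+ */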
+public class SortedIteratorForSpills extends UnsafeSorterIterator {
+ private LongArray sortedArray;
+ private final int numRecords;
+ private int position;
+ private int offset;
+ private Object baseObject;
+ private long baseOffset;
+ private long keyPrefix;
+ private int recordLength;
+ private long currentPageNumber;
+ private final TaskContext taskContext = TaskContext.get();
+ private final TaskMemoryManager memoryManager;
+
+ /**
+ * Construct an iterator over the records referenced by an already-sorted LongArray.
+ * @param memoryManager the TaskMemoryManager that owns the record pages
+ * @param array the LongArray of (record pointer, key prefix) pairs; it must already be sorted
+ * @param numRecords the number of records referenced by this LongArray
+ * @param offset the index in the LongArray at which the sorted pairs start
+ */
+ public SortedIteratorForSpills(
+ final TaskMemoryManager memoryManager,
+ LongArray array,
+ int numRecords,
+ int offset) {
+ this.memoryManager = memoryManager;
+ this.sortedArray = array;
+ this.numRecords = numRecords;
+ this.position = 0;
+ this.offset = offset;
+ }
+
+ public static SortedIteratorForSpills createFromExistingSorterIte(
+ UnsafeInMemorySorter.SortedIterator sortedIte,
+ UnsafeInMemorySorter inMemSorter) {
+ if (sortedIte == null) {
+ return null;
+ }
+ TaskMemoryManager taskMemoryManager = inMemSorter.getTaskMemoryManager();
+ LongArray array = inMemSorter.getLongArray();
+ int numberRecords = sortedIte.getNumRecords();
+ int offset = sortedIte.getOffset();
+ int position = sortedIte.getPosition();
+ SortedIteratorForSpills spillIte = new SortedIteratorForSpills(taskMemoryManager, array, numberRecords, offset);
+ spillIte.pointTo(position);
+ return spillIte;
+ }
+
+ @Override
+ public int getNumRecords() {
+ return numRecords;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return position / 2 < numRecords;
+ }
+
+ @Override
+ public void loadNext() {
+ // Kill the task in case it has been marked as killed. This logic is from
+ // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+ // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+ // `hasNext()` because it's technically possible for the caller to be relying on
+ // `getNumRecords()` instead of `hasNext()` to know when to stop.
+ if (taskContext != null) {
+ taskContext.killTaskIfInterrupted();
+ }
+ loadPosition();
+ }
+
+ /**
+ * Load the record at the current position and advance past it.
+ */
+ private void loadPosition() {
+ // This pointer points to a 4-byte record length, followed by the record's bytes
+ final long recordPointer = sortedArray.get(offset + position);
+ currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer);
+ int uaoSize = UnsafeAlignedOffset.getUaoSize();
+ baseObject = memoryManager.getPage(recordPointer);
+ // Skip over record length
+ baseOffset = memoryManager.getOffsetInPage(recordPointer) + uaoSize;
+ recordLength = UnsafeAlignedOffset.getSize(baseObject, baseOffset - uaoSize);
+ keyPrefix = sortedArray.get(offset + position + 1);
+ position += 2;
+ }
+
+ /**
+ * Reposition the iterator.
+ * @param pos the new position; it must be even, i.e. point at the start of a (pointer, prefix) pair
+ */
+ public void pointTo(int pos) {
+ if (pos % 2 != 0) {
+ throw new IllegalArgumentException("Can't point to the middle of a record.");
+ }
+ position = pos;
+ }
+
+ @Override
+ public Object getBaseObject() { return baseObject; }
+
+ @Override
+ public long getBaseOffset() { return baseOffset; }
+
+ public long getCurrentPageNumber() { return currentPageNumber; }
+
+ @Override
+ public int getRecordLength() { return recordLength; }
+
+ @Override
+ public long getKeyPrefix() { return keyPrefix; }
+
+ public LongArray getLongArray() { return sortedArray; }
+
+ public int getPosition() { return position; }
+}
diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/SortedPMemPageSpillWriter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/SortedPMemPageSpillWriter.java
new file mode 100644
index 00000000..bd8ca9c8
--- /dev/null
+++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/SortedPMemPageSpillWriter.java
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import com.intel.oap.common.unsafe.PersistentMemoryPlatform;
+
+import org.apache.spark.SparkEnv;
+import org.apache.spark.internal.config.package$;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.serializer.SerializerManager;
+import java.io.IOException;
+import java.util.LinkedHashMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
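+/**
+ * Spill writer that copies records one at a time, in sorted order, into PMem pages allocated on
+ * demand. Once no more PMem can be allocated, the remaining records are written to disk through
+ * a regular UnsafeSorterSpillWriter, and the returned reader chains both sources.
+ */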
+public class SortedPMemPageSpillWriter extends UnsafeSorterPMemSpillWriter {
+ private static final Logger sorted_logger = LoggerFactory.getLogger(SortedPMemPageSpillWriter.class);
+ private MemoryBlock currentPMemPage = null;
+ private long currentOffsetInPage = 0L;
+ private int currentNumOfRecordsInPage = 0;
+ // PMem page -> number of records written to that page
+ private LinkedHashMap<MemoryBlock, Integer> pageNumOfRecMap = new LinkedHashMap<>();
+ private int numRecords = 0;
+ private int numRecordsOnPMem = 0;
+
+ private BlockManager blockManager;
+ private SerializerManager serializerManager;
+ private int fileBufferSize = 0;
+ private UnsafeSorterSpillWriter diskSpillWriter;
+
+ private final boolean pMemClflushEnabled = SparkEnv.get() != null &&
+ (boolean)SparkEnv.get().conf().get(package$.MODULE$.MEMORY_SPILL_PMEM_CLFLUSH_ENABLED());
+
+ public SortedPMemPageSpillWriter(
+ UnsafeExternalSorter externalSorter,
+ SortedIteratorForSpills sortedIterator,
+ int numberOfRecordsToWritten,
+ SerializerManager serializerManager,
+ BlockManager blockManager,
+ int fileBufferSize,
+ ShuffleWriteMetrics writeMetrics,
+ TaskMetrics taskMetrics) {
+ super(externalSorter, sortedIterator, numberOfRecordsToWritten, writeMetrics, taskMetrics);
+ this.blockManager = blockManager;
+ this.serializerManager = serializerManager;
+ this.fileBufferSize = fileBufferSize;
+ }
+
+ @Override
+ public void write() throws IOException {
+ boolean allBeWritten = writeToPMem();
+ if (!allBeWritten) {
+ sorted_logger.debug("No more PMEM space available. Write left spills to disk.");
+ writeToDisk();
+ }
+ }
+
+ /**
+ * @return true if all records have been written to PMem; false otherwise.
+ * @throws IOException
+ */
+ private boolean writeToPMem() throws IOException {
+ while (sortedIterator.hasNext()) {
+ sortedIterator.loadNext();
+ final Object baseObject = sortedIterator.getBaseObject();
+ final long baseOffset = sortedIterator.getBaseOffset();
+ int curRecLen = sortedIterator.getRecordLength();
+ long curPrefix = sortedIterator.getKeyPrefix();
+ if (needNewPMemPage(curRecLen)) {
+ currentPMemPage = allocatePMemPage();
+ }
+ if (currentPMemPage != null) {
+ long pageBaseOffset = currentPMemPage.getBaseOffset();
+ long curPMemOffset = pageBaseOffset + currentOffsetInPage;
+ writeRecordToPMem(baseObject, baseOffset, curRecLen, curPrefix, curPMemOffset);
+ currentNumOfRecordsInPage++;
+ pageNumOfRecMap.put(currentPMemPage, currentNumOfRecordsInPage);
+ numRecords++;
+ } else {
+ // No more PMem space available; the currently loaded record cannot be written to PMem.
+ return false;
+ }
+ }
+ //All records have been written to PMem.
+ return true;
+ }
+
+ private void writeToDisk() throws IOException {
+ int numOfRecLeft = numberOfRecordsToWritten - numRecordsOnPMem;
+ if (diskSpillWriter == null) {
+ diskSpillWriter = new UnsafeSorterSpillWriter(
+ blockManager,
+ fileBufferSize,
+ sortedIterator,
+ numOfRecLeft,
+ serializerManager,
+ writeMetrics,
+ taskMetrics);
+ }
+ diskSpillWriter.write(true);
+ sorted_logger.info("Num of rec {}; Num of rec written to PMem {}; still {} records left; num of rec written to disk {}.",
+ sortedIterator.getNumRecords(),
+ numRecordsOnPMem,
+ numOfRecLeft,
+ diskSpillWriter.recordsSpilled());
+ }
+
+ private boolean needNewPMemPage(int nextRecLen) {
+ if (allocatedPMemPages.isEmpty()) {
+ return true;
+ } else {
+ long leftLenInCurPage = currentPMemPage.size() - currentOffsetInPage;
+ int uaoSize = UnsafeAlignedOffset.getUaoSize();
+ long recSizeRequired = uaoSize + Long.BYTES + nextRecLen;
+ if (leftLenInCurPage < recSizeRequired) {
+ return true;
+ }
+ }
+ return false;
+ }
+
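+ /**
+ * Writes a single record at pMemOffset using the layout expected by the spill reader:
+ * a uaoSize-byte record length slot, an 8-byte key prefix, then the record bytes.
+ */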
+ private void writeRecordToPMem(Object baseObject, long baseOffset, int recLength, long prefix, long pMemOffset) {
+ Platform.putInt(
+ null,
+ pMemOffset,
+ recLength);
+ int uaoSize = UnsafeAlignedOffset.getUaoSize();
+ long currentOffset = pMemOffset + uaoSize;
+ Platform.putLong(
+ null,
+ currentOffset,
+ prefix);
+ currentOffset += Long.BYTES;
+ PersistentMemoryPlatform.copyMemory(
+ baseObject,
+ baseOffset,
+ null,
+ currentOffset,
+ recLength,
+ pMemClflushEnabled);
+ numRecordsOnPMem++;
+ currentOffsetInPage += uaoSize + Long.BYTES + recLength;
+ }
+
+ protected MemoryBlock allocatePMemPage() throws IOException {
+ currentPMemPage = super.allocatePMemPage();
+ currentOffsetInPage = 0;
+ currentNumOfRecordsInPage = 0;
+ return currentPMemPage;
+ }
+
+ @Override
+ public UnsafeSorterIterator getSpillReader() throws IOException {
+ return new SortedPMemPageSpillReader();
+ }
+
+ @Override
+ public void clearAll() {
+ freeAllPMemPages();
+ if (diskSpillWriter != null) {
+ diskSpillWriter.clearAll();
+ }
+ }
+
+ public int recordsSpilled() {
+ int recordsSpilledOnDisk = 0;
+ if (diskSpillWriter != null) {
+ recordsSpilledOnDisk = diskSpillWriter.recordsSpilled();
+ }
+ return numRecordsOnPMem + recordsSpilledOnDisk;
+ }
+
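+ /**
+ * Reads records back in order: first the records stored on PMem pages, then any records that
+ * overflowed to the disk spill file.
+ */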
+ private class SortedPMemPageSpillReader extends UnsafeSorterIterator {
+ private final Logger sorted_reader_logger = LoggerFactory.getLogger(SortedPMemPageSpillReader.class);
+ private MemoryBlock curPage = null;
+ private int curPageIdx = -1;
+ private int curOffsetInPage = 0;
+ private int curNumOfRecInPage = 0;
+ private int curNumOfRec = 0;
+ private Object baseObject = null;
+ private long baseOffset = 0;
+ private int recordLength;
+ private long keyPrefix;
+ private UnsafeSorterIterator diskSpillReader;
+ private int numRecordsOnDisk = 0;
+
+ public SortedPMemPageSpillReader() throws IOException {
+ if (diskSpillWriter != null) {
+ diskSpillReader = diskSpillWriter.getSpillReader();
+ numRecordsOnDisk = diskSpillWriter.recordsSpilled();
+ }
+ }
+ @Override
+ public boolean hasNext() {
+ return curNumOfRec < numRecordsOnPMem + numRecordsOnDisk;
+ }
+ @Override
+ public void loadNext() throws IOException {
+ if (curNumOfRec < numRecordsOnPMem) {
+ loadNextOnPMem();
+ } else {
+ loadNextOnDisk();
+ }
+ }
+
+ private void loadNextOnPMem() throws IOException {
+ if (curPage == null || curNumOfRecInPage >= pageNumOfRecMap.get(curPage)) {
+ moveToNextPMemPage();
+ }
+ long curPageBaseOffset = curPage.getBaseOffset();
+ recordLength = UnsafeAlignedOffset.getSize(null, curPageBaseOffset + curOffsetInPage);
+ curOffsetInPage += UnsafeAlignedOffset.getUaoSize();
+ keyPrefix = Platform.getLong(null, curPageBaseOffset + curOffsetInPage);
+ curOffsetInPage += Long.BYTES;
+ baseOffset = curPageBaseOffset + curOffsetInPage;
+ curOffsetInPage += recordLength;
+ curNumOfRecInPage++;
+ curNumOfRec++;
+ }
+
+ private void loadNextOnDisk() throws IOException {
+ if (diskSpillReader != null && diskSpillReader.hasNext()) {
+ diskSpillReader.loadNext();
+ baseObject = diskSpillReader.getBaseObject();
+ baseOffset = diskSpillReader.getBaseOffset();
+ recordLength = diskSpillReader.getRecordLength();
+ keyPrefix = diskSpillReader.getKeyPrefix();
+ curNumOfRec++;
+ }
+ }
+
+ private void moveToNextPMemPage() {
+ curPageIdx++;
+ curPage = allocatedPMemPages.get(curPageIdx);
+ curOffsetInPage = 0;
+ curNumOfRecInPage = 0;
+ }
+
+ @Override
+ public Object getBaseObject() {
+ return baseObject;
+ }
+
+ @Override
+ public long getBaseOffset() {
+ return baseOffset;
+ }
+
+ @Override
+ public int getRecordLength() {
+ return recordLength;
+ }
+
+ @Override
+ public long getKeyPrefix() {
+ return keyPrefix;
+ }
+
+ @Override
+ public int getNumRecords() {
+ return numRecordsOnPMem + numRecordsOnDisk;
+ }
+ }
+}
diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/SpillWriterForUnsafeSorter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/SpillWriterForUnsafeSorter.java
new file mode 100644
index 00000000..06a670e8
--- /dev/null
+++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/SpillWriterForUnsafeSorter.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.io.IOException;
+
+public interface SpillWriterForUnsafeSorter {
+ // Writes out all records covered by this spill.
+ public void write() throws IOException;
+
+ // Returns a reader over the spill maintained by this writer.
+ public UnsafeSorterIterator getSpillReader() throws IOException;
+
+ // Releases all resources acquired by this writer once reading is done.
+ public void clearAll();
+
+ // Returns the number of records spilled by this writer.
+ public int recordsSpilled();
+}
diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 6b587ac0..9ecb6e4e 100644
--- a/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -25,20 +25,17 @@
import java.util.function.Supplier;
import com.google.common.annotations.VisibleForTesting;
-
-import com.intel.oap.common.storage.stream.ChunkInputStream;
-import com.intel.oap.common.storage.stream.DataStore;
-
+import org.apache.spark.SparkEnv;
+import org.apache.spark.internal.config.package$;
+import org.apache.spark.memory.SparkOutOfMemoryError;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.SparkEnv;
+import org.apache.spark.internal.config.package$;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.internal.config.package$;
import org.apache.spark.memory.MemoryConsumer;
-import org.apache.spark.memory.PMemManagerInitializer;
-import org.apache.spark.memory.SparkOutOfMemoryError;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.memory.TooLargePageException;
import org.apache.spark.serializer.SerializerManager;
@@ -80,7 +77,12 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
* Force this sorter to spill when there are this many elements in memory.
*/
private final int numElementsForSpillThreshold;
-
+ private final boolean spillToPMemEnabled = SparkEnv.get() != null && (boolean) SparkEnv.get().conf().get(
+ package$.MODULE$.MEMORY_SPILL_PMEM_ENABLED());
+ /**
+ * The configured spill writer type; must name a {@link PMemSpillWriterType} value
+ * (read from USAFE_EXTERNAL_SORTER_SPILL_WRITE_TYPE).
+ */
+ private String spillWriterType = null;
/**
* Memory pages that hold the records being sorted. The pages in this list are freed when
* spilling, although in principle we could recycle these pages across spills (on the other hand,
@@ -89,7 +91,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
*/
private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<>();
- private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
+ private final LinkedList<SpillWriterForUnsafeSorter> spillWriters = new LinkedList<>();
// These variables are reset after spilling:
@Nullable private volatile UnsafeInMemorySorter inMemSorter;
@@ -159,7 +161,12 @@ private UnsafeExternalSorter(
// Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
// this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024
this.fileBufferSizeBytes = 32 * 1024;
-
+ SparkEnv sparkEnv = SparkEnv.get();
+ if (sparkEnv != null && sparkEnv.conf() != null) {
+ this.spillWriterType = sparkEnv.conf().get(package$.MODULE$.USAFE_EXTERNAL_SORTER_SPILL_WRITE_TYPE());
+ } else {
+ this.spillWriterType = PMemSpillWriterType.WRITE_SORTED_RECORDS_TO_PMEM.toString();
+ }
if (existingInMemorySorter == null) {
RecordComparator comparator = null;
if (recordComparatorSupplier != null) {
@@ -214,34 +221,52 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {
}
logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
- Thread.currentThread().getId(),
- Utils.bytesToString(getMemoryUsage()),
- spillWriters.size(),
- spillWriters.size() > 1 ? " times" : " time");
-
+ Thread.currentThread().getId(),
+ Utils.bytesToString(getMemoryUsage()),
+ spillWriters.size(),
+ spillWriters.size() > 1 ? " times" : " time");
ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
-
- final UnsafeSorterSpillWriter spillWriter =
- new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
- inMemSorter.numRecords());
- spillWriters.add(spillWriter);
- spillIterator(inMemSorter.getSortedIterator(), spillWriter);
-
+ // Whether records are sorted before writing is handled by the specific spill writer,
+ // so null is passed for the iterator here.
final long spillSize = freeMemory();
+ inMemSorter.reset();
// Note that this is more-or-less going to be a multiple of the page size, so wasted space in
// pages will currently be counted as memory spilled even though that space isn't actually
// written to disk. This also counts the space needed to store the sorter's pointer array.
- inMemSorter.reset();
// Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the
// records. Otherwise, if the task is over allocated memory, then without freeing the memory
// pages, we might not be able to get memory for the pointer array.
-
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten());
totalSpillBytes += spillSize;
return spillSize;
}
+ // TODO: it's confusing to pass ShuffleWriteMetrics in here; reconsider and fix this later.
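+ /**
+ * Creates the configured spill writer, writes the given records through it, and registers the
+ * writer so its spill can later be merged and cleaned up in cleanupResources().
+ */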
+ public SpillWriterForUnsafeSorter spillWithWriter(
+ UnsafeSorterIterator sortedIterator,
+ int numberOfRecordsToWritten,
+ ShuffleWriteMetrics writeMetrics,
+ boolean isSorted) throws IOException {
+ PMemSpillWriterType writerType = PMemSpillWriterType.valueOf(spillWriterType);
+ logger.info("PMemSpillWriterType:{}",writerType.toString());
+ final SpillWriterForUnsafeSorter spillWriter = PMemSpillWriterFactory.getSpillWriter(
+ writerType,
+ this,
+ sortedIterator,
+ numberOfRecordsToWritten,
+ serializerManager,
+ blockManager,
+ fileBufferSizeBytes,
+ writeMetrics,
+ taskContext.taskMetrics(),
+ spillToPMemEnabled,
+ isSorted);
+ spillWriter.write();
+ spillWriters.add(spillWriter);
+ return spillWriter;
+ }
+
/**
* Return the total memory usage of this sorter, including the data pages and the sorter's pointer
* array.
@@ -292,6 +317,9 @@ public int getNumberOfAllocatedPages() {
return allocatedPages.size();
}
+ public LinkedList<MemoryBlock> getAllocatedPages() {
+ return this.allocatedPages;
+ }
/**
* Free this sorter's data pages.
*
@@ -310,41 +338,12 @@ private long freeMemory() {
return memoryFreed;
}
- /**
- * Deletes any spill files created by this sorter.
- */
- private void deleteSpillFiles() {
- boolean pMemSpillEnabled = SparkEnv.get() != null && (boolean) SparkEnv.get().conf().get(
- package$.MODULE$.MEMORY_SPILL_PMEM_ENABLED());
- if (pMemSpillEnabled == true)
- deletePMemSpillFiles();
- else
- deleteDiskSpillFiles();
- }
- private void deleteDiskSpillFiles() {
- while (!spillWriters.isEmpty()) {
- File file = spillWriters.removeFirst().getFile();
- if (file != null && file.exists()) {
- if (!file.delete()) {
- logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
- }
- }
- }
- }
-
- private void deletePMemSpillFiles() {
- while (!spillWriters.isEmpty()) {
- File file = spillWriters.removeFirst().getFile();
- try {
- ChunkInputStream cis = ChunkInputStream.getChunkInputStreamInstance(file.toString(),
- new DataStore(PMemManagerInitializer.getPMemManager(),
- PMemManagerInitializer.getProperties()));
- cis.free();
- } catch (IOException e) {
- logger.debug(e.toString());
- }
+ private void freeSpills() {
+ for (SpillWriterForUnsafeSorter spillWriter : spillWriters) {
+ spillWriter.clearAll();
}
+ spillWriters.clear();
}
/**
@@ -352,7 +351,7 @@ private void deletePMemSpillFiles() {
*/
public void cleanupResources() {
synchronized (this) {
- deleteSpillFiles();
+ freeSpills();
freeMemory();
if (inMemSorter != null) {
inMemSorter.free();
@@ -470,7 +469,6 @@ public void insertKVRecord(Object keyBase, long keyOffset, int keyLen,
pageCursor += keyLen;
Platform.copyMemory(valueBase, valueOffset, base, pageCursor, valueLen);
pageCursor += valueLen;
-
assert(inMemSorter != null);
inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
}
@@ -499,8 +497,8 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
} else {
final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger(
recordComparatorSupplier.get(), prefixComparator, spillWriters.size());
- for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
- spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager));
+ for (SpillWriterForUnsafeSorter spillWriter: spillWriters) {
+ spillMerger.addSpillIfNotEmpty(spillWriter.getSpillReader());
}
if (inMemSorter != null) {
readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
@@ -509,23 +507,12 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
return spillMerger.getSortedIterator();
}
}
+ public UnsafeInMemorySorter getInMemSorter() { return inMemSorter; }
@VisibleForTesting boolean hasSpaceForAnotherRecord() {
return inMemSorter.hasSpaceForAnotherRecord();
}
- private static void spillIterator(UnsafeSorterIterator inMemIterator,
- UnsafeSorterSpillWriter spillWriter) throws IOException {
- while (inMemIterator.hasNext()) {
- inMemIterator.loadNext();
- final Object baseObject = inMemIterator.getBaseObject();
- final long baseOffset = inMemIterator.getBaseOffset();
- final int recordLength = inMemIterator.getRecordLength();
- spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix());
- }
- spillWriter.close();
- }
-
/**
* An UnsafeSorterIterator that support spilling.
*/
@@ -552,19 +539,14 @@ public long spill() throws IOException {
&& numRecords > 0)) {
return 0L;
}
-
UnsafeInMemorySorter.SortedIterator inMemIterator =
- ((UnsafeInMemorySorter.SortedIterator) upstream).clone();
-
- ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
- // Iterate over the records that have not been returned and spill them.
- final UnsafeSorterSpillWriter spillWriter =
- new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords);
- spillIterator(inMemIterator, spillWriter);
- spillWriters.add(spillWriter);
- nextUpstream = spillWriter.getReader(serializerManager);
-
+ ((UnsafeInMemorySorter.SortedIterator) upstream).clone();
+
+ ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
long released = 0L;
+ SpillWriterForUnsafeSorter spillWriter = spillWithWriter(inMemIterator, numRecords, writeMetrics, true);
+ nextUpstream = spillWriter.getSpillReader();
+
synchronized (UnsafeExternalSorter.this) {
// release the pages except the one that is used. There can still be a caller that
// is accessing the current record. We free this page in that caller's next loadNext()
@@ -668,9 +650,9 @@ public UnsafeSorterIterator getIterator(int startIndex) throws IOException {
} else {
LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
int i = 0;
- for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
+ for (SpillWriterForUnsafeSorter spillWriter : spillWriters) {
if (i + spillWriter.recordsSpilled() > startIndex) {
- UnsafeSorterIterator iter = spillWriter.getReader(serializerManager);
+ UnsafeSorterIterator iter = spillWriter.getSpillReader();
moveOver(iter, startIndex - i);
queue.add(iter);
}
diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
new file mode 100644
index 00000000..b1c1d393
--- /dev/null
+++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -0,0 +1,435 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import java.util.Comparator;
+import java.util.LinkedList;
+
+import org.apache.avro.reflect.Nullable;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.SparkOutOfMemoryError;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.util.collection.Sorter;
+
+/**
+ * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records
+ * alongside a user-defined prefix of the record's sorting key. When the underlying sort algorithm
+ * compares records, it will first compare the stored key prefixes; if the prefixes are not equal,
+ * then we do not need to traverse the record pointers to compare the actual records. Avoiding these
+ * random memory accesses improves cache hit rates.
+ */
+public final class UnsafeInMemorySorter {
+
+ private static final class SortComparator implements Comparator<RecordPointerAndKeyPrefix> {
+
+ private final RecordComparator recordComparator;
+ private final PrefixComparator prefixComparator;
+ private final TaskMemoryManager memoryManager;
+
+ SortComparator(
+ RecordComparator recordComparator,
+ PrefixComparator prefixComparator,
+ TaskMemoryManager memoryManager) {
+ this.recordComparator = recordComparator;
+ this.prefixComparator = prefixComparator;
+ this.memoryManager = memoryManager;
+ }
+
+ @Override
+ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
+ final int prefixComparisonResult = prefixComparator.compare(r1.keyPrefix, r2.keyPrefix);
+ int uaoSize = UnsafeAlignedOffset.getUaoSize();
+ if (prefixComparisonResult == 0) {
+ final Object baseObject1 = memoryManager.getPage(r1.recordPointer);
+ final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + uaoSize;
+ final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize);
+ final Object baseObject2 = memoryManager.getPage(r2.recordPointer);
+ final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + uaoSize;
+ final int baseLength2 = UnsafeAlignedOffset.getSize(baseObject2, baseOffset2 - uaoSize);
+ return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2,
+ baseOffset2, baseLength2);
+ } else {
+ return prefixComparisonResult;
+ }
+ }
+ }
+
+ private final MemoryConsumer consumer;
+ private final TaskMemoryManager memoryManager;
+ @Nullable
+ private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
+
+ /**
+ * If non-null, specifies the radix sort parameters and that radix sort will be used.
+ */
+ @Nullable
+ private final PrefixComparators.RadixSortSupport radixSortSupport;
+
+ /**
+ * Within this buffer, position {@code 2 * i} holds a pointer to the record at
+ * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
+ *
+ * Only part of the array will be used to store the pointers, the rest part is preserved as
+ * temporary buffer for sorting.
+ */
+ private LongArray array;
+
+ /**
+ * The position in the sort buffer where new records can be inserted.
+ */
+ private int pos = 0;
+
+ /**
+ * If sorting with radix sort, specifies the starting position in the sort buffer where records
+ * with non-null prefixes are kept. Positions [0..nullBoundaryPos) will contain null-prefixed
+ * records, and positions [nullBoundaryPos..pos) non-null prefixed records. This lets us avoid
+ * radix sorting over null values.
+ */
+ private int nullBoundaryPos = 0;
+
+ /*
+ * How many records could be inserted, because part of the array should be left for sorting.
+ */
+ private int usableCapacity = 0;
+
+ private long initialSize;
+
+ private long totalSortTimeNanos = 0L;
+
+ public UnsafeInMemorySorter(
+ final MemoryConsumer consumer,
+ final TaskMemoryManager memoryManager,
+ final RecordComparator recordComparator,
+ final PrefixComparator prefixComparator,
+ int initialSize,
+ boolean canUseRadixSort) {
+ this(consumer, memoryManager, recordComparator, prefixComparator,
+ consumer.allocateArray(initialSize * 2L), canUseRadixSort);
+ }
+
+ public UnsafeInMemorySorter(
+ final MemoryConsumer consumer,
+ final TaskMemoryManager memoryManager,
+ final RecordComparator recordComparator,
+ final PrefixComparator prefixComparator,
+ LongArray array,
+ boolean canUseRadixSort) {
+ this.consumer = consumer;
+ this.memoryManager = memoryManager;
+ this.initialSize = array.size();
+ if (recordComparator != null) {
+ this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+ if (canUseRadixSort && prefixComparator instanceof PrefixComparators.RadixSortSupport) {
+ this.radixSortSupport = (PrefixComparators.RadixSortSupport)prefixComparator;
+ } else {
+ this.radixSortSupport = null;
+ }
+ } else {
+ this.sortComparator = null;
+ this.radixSortSupport = null;
+ }
+ this.array = array;
+ this.usableCapacity = getUsableCapacity();
+ }
+
+ private int getUsableCapacity() {
+ // Radix sort requires same amount of used memory as buffer, Tim sort requires
+ // half of the used memory as buffer.
+ return (int) (array.size() / (radixSortSupport != null ? 2 : 1.5));
+ }
+
+ /**
+ * Drop the reference to the pointer array without freeing its memory.
+ */
+ public void freeWithoutLongArray() {
+ if (consumer != null) {
+ array = null;
+ }
+ }
+
+ public void free() {
+ if (consumer != null) {
+ if (array != null) {
+ consumer.freeArray(array);
+ }
+ array = null;
+ }
+ }
+
+ public void resetWithoutLongArrray() {
+ if (consumer != null) {
+ // the call to consumer.allocateArray may trigger a spill which in turn access this instance
+ // and eventually re-enter this method and try to free the array again. by setting the array
+ // to null and its length to 0 we effectively make the spill code-path a no-op. setting the
+ // array to null also indicates that it has already been de-allocated which prevents a double
+ // de-allocation in free().
+ array = null;
+ usableCapacity = 0;
+ pos = 0;
+ nullBoundaryPos = 0;
+ array = consumer.allocateArray(initialSize);
+ usableCapacity = getUsableCapacity();
+ }
+ pos = 0;
+ nullBoundaryPos = 0;
+ }
+
+ public void reset() {
+ if (consumer != null) {
+ consumer.freeArray(array);
+ // the call to consumer.allocateArray may trigger a spill which in turn access this instance
+ // and eventually re-enter this method and try to free the array again. by setting the array
+ // to null and its length to 0 we effectively make the spill code-path a no-op. setting the
+ // array to null also indicates that it has already been de-allocated which prevents a double
+ // de-allocation in free().
+ array = null;
+ usableCapacity = 0;
+ pos = 0;
+ nullBoundaryPos = 0;
+ array = consumer.allocateArray(initialSize);
+ usableCapacity = getUsableCapacity();
+ }
+ pos = 0;
+ nullBoundaryPos = 0;
+ }
+
+ /**
+ * @return the number of records that have been inserted into this sorter.
+ */
+ public int numRecords() {
+ return pos / 2;
+ }
+
+ /**
+ * @return the total amount of time spent sorting data (in-memory only).
+ */
+ public long getSortTimeNanos() {
+ return totalSortTimeNanos;
+ }
+
+ public long getMemoryUsage() {
+ if (array == null) {
+ return 0L;
+ }
+
+ return array.size() * 8;
+ }
+
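+ /**
+ * Sorts the pointer array in place (via getSortedIterator) and returns it.
+ */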
+ public LongArray getSortedArray() {
+ getSortedIterator();
+ return array;
+ }
+
+ public LongArray getArray() {
+ return array;
+ }
+
+ public boolean hasSpaceForAnotherRecord() {
+ return pos + 1 < usableCapacity;
+ }
+
+ public void expandPointerArray(LongArray newArray) {
+ if (newArray.size() < array.size()) {
+ // checkstyle.off: RegexpSinglelineJava
+ throw new SparkOutOfMemoryError("Not enough memory to grow pointer array");
+ // checkstyle.on: RegexpSinglelineJava
+ }
+ Platform.copyMemory(
+ array.getBaseObject(),
+ array.getBaseOffset(),
+ newArray.getBaseObject(),
+ newArray.getBaseOffset(),
+ pos * 8L);
+ consumer.freeArray(array);
+ array = newArray;
+ usableCapacity = getUsableCapacity();
+ }
+
+ /**
+ * Inserts a record to be sorted. Assumes that the record pointer points to a record length
+ * stored as a uaoSize(4 or 8) bytes integer, followed by the record's bytes.
+ *
+ * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
+ * @param keyPrefix a user-defined key prefix
+ */
+ public void insertRecord(long recordPointer, long keyPrefix, boolean prefixIsNull) {
+ if (!hasSpaceForAnotherRecord()) {
+ throw new IllegalStateException("There is no space for new record");
+ }
+ if (prefixIsNull && radixSortSupport != null) {
+ // Swap forward a non-null record to make room for this one at the beginning of the array.
+ array.set(pos, array.get(nullBoundaryPos));
+ pos++;
+ array.set(pos, array.get(nullBoundaryPos + 1));
+ pos++;
+ // Place this record in the vacated position.
+ array.set(nullBoundaryPos, recordPointer);
+ nullBoundaryPos++;
+ array.set(nullBoundaryPos, keyPrefix);
+ nullBoundaryPos++;
+ } else {
+ array.set(pos, recordPointer);
+ pos++;
+ array.set(pos, keyPrefix);
+ pos++;
+ }
+ }
+
+ public final class SortedIterator extends UnsafeSorterIterator implements Cloneable {
+
+ private final int numRecords;
+ private int position;
+ private int offset;
+ private Object baseObject;
+ private long baseOffset;
+ private long keyPrefix;
+ private long currentRecordPointer;
+ private int recordLength;
+ private long currentPageNumber;
+ private final TaskContext taskContext = TaskContext.get();
+
+ private SortedIterator(int numRecords, int offset) {
+ this.numRecords = numRecords;
+ this.position = 0;
+ this.offset = offset;
+ }
+
+ public SortedIterator clone() {
+ SortedIterator iter = new SortedIterator(numRecords, offset);
+ iter.position = position;
+ iter.baseObject = baseObject;
+ iter.baseOffset = baseOffset;
+ iter.keyPrefix = keyPrefix;
+ iter.recordLength = recordLength;
+ iter.currentPageNumber = currentPageNumber;
+ return iter;
+ }
+
+ @Override
+ public int getNumRecords() {
+ return numRecords;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return position / 2 < numRecords;
+ }
+
+ @Override
+ public void loadNext() {
+ // Kill the task in case it has been marked as killed. This logic is from
+ // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order
+ // to avoid performance overhead. This check is added here in `loadNext()` instead of in
+ // `hasNext()` because it's technically possible for the caller to be relying on
+ // `getNumRecords()` instead of `hasNext()` to know when to stop.
+ if (taskContext != null) {
+ taskContext.killTaskIfInterrupted();
+ }
+ // This pointer points to a 4-byte record length, followed by the record's bytes
+ final long recordPointer = array.get(offset + position);
+ currentRecordPointer = recordPointer;
+ currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer);
+ int uaoSize = UnsafeAlignedOffset.getUaoSize();
+ baseObject = memoryManager.getPage(recordPointer);
+ // Skip over record length
+ baseOffset = memoryManager.getOffsetInPage(recordPointer) + uaoSize;
+ recordLength = UnsafeAlignedOffset.getSize(baseObject, baseOffset - uaoSize);
+ keyPrefix = array.get(offset + position + 1);
+ position += 2;
+ }
+
+ @Override
+ public Object getBaseObject() { return baseObject; }
+
+ @Override
+ public long getBaseOffset() { return baseOffset; }
+
+ public long getCurrentPageNumber() {
+ return currentPageNumber;
+ }
+
+ @Override
+ public int getRecordLength() { return recordLength; }
+
+ @Override
+ public long getKeyPrefix() { return keyPrefix; }
+
+ public int getPosition() { return position; }
+
+ public int getOffset() { return offset; }
+
+ public long getCurrentRecordPointer() { return currentRecordPointer; }
+ }
+
+ /**
+ * Return an iterator over record pointers in sorted order. For efficiency, all calls to
+ * {@code next()} will return the same mutable object.
+ */
+ public UnsafeSorterIterator getSortedIterator() {
+ int offset = 0;
+ long start = System.nanoTime();
+ if (sortComparator != null) {
+ if (this.radixSortSupport != null) {
+ offset = RadixSort.sortKeyPrefixArray(
+ array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7,
+ radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
+ } else {
+ MemoryBlock unused = new MemoryBlock(
+ array.getBaseObject(),
+ array.getBaseOffset() + pos * 8L,
+ (array.size() - pos) * 8L);
+ LongArray buffer = new LongArray(unused);
+ Sorter<RecordPointerAndKeyPrefix, LongArray> sorter =
+ new Sorter<>(new UnsafeSortDataFormat(buffer));
+ sorter.sort(array, 0, pos / 2, sortComparator);
+ }
+ }
+ totalSortTimeNanos += System.nanoTime() - start;
+ if (nullBoundaryPos > 0) {
+ assert radixSortSupport != null : "Nulls are only stored separately with radix sort";
+ LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
+
+ // The null order is either LAST or FIRST, regardless of sorting direction (ASC|DESC)
+ if (radixSortSupport.nullsFirst()) {
+ queue.add(new SortedIterator(nullBoundaryPos / 2, 0));
+ queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset));
+ } else {
+ queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset));
+ queue.add(new SortedIterator(nullBoundaryPos / 2, 0));
+ }
+ return new UnsafeExternalSorter.ChainedIterator(queue);
+ } else {
+ return new SortedIterator(pos / 2, offset);
+ }
+ }
+
+ public LongArray getLongArray() {
+ return array;
+ }
+
+ public TaskMemoryManager getTaskMemoryManager() {
+ return memoryManager;
+ }
+
+}
diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterPMemSpillWriter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterPMemSpillWriter.java
new file mode 100644
index 00000000..becab97d
--- /dev/null
+++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterPMemSpillWriter.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+import java.io.IOException;
+import java.util.LinkedList;
+
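+/**
+ * Base class for spill writers that place records on persistent memory (PMem). It tracks the
+ * PMem pages allocated through the TaskMemoryManager on behalf of the external sorter and frees
+ * them once the spill has been consumed.
+ */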
+public abstract class UnsafeSorterPMemSpillWriter implements SpillWriterForUnsafeSorter {
+ /**
+ * The external sorter whose records are being spilled; it also serves as the MemoryConsumer
+ * for PMem page allocation.
+ */
+ protected UnsafeExternalSorter externalSorter;
+
+ protected SortedIteratorForSpills sortedIterator;
+
+ protected int numberOfRecordsToWritten = 0;
+
+ protected TaskMemoryManager taskMemoryManager;
+
+ //Todo: It's confusing to have ShuffleWriteMetrics here. will reconsider and fix it later.
+ protected ShuffleWriteMetrics writeMetrics;
+
+ protected TaskMetrics taskMetrics;
+
+ protected LinkedList<MemoryBlock> allocatedPMemPages = new LinkedList<>();
+
+ // Default PMem page size in bytes (64 MB).
+ private static final long DEFAULT_PAGE_SIZE = 64 * 1024 * 1024;
+
+ public UnsafeSorterPMemSpillWriter(
+ UnsafeExternalSorter externalSorter,
+ SortedIteratorForSpills sortedIterator,
+ int numberOfRecordsToWritten,
+ ShuffleWriteMetrics writeMetrics,
+ TaskMetrics taskMetrics) {
+ this.externalSorter = externalSorter;
+ this.taskMemoryManager = externalSorter.getTaskMemoryManager();
+ this.sortedIterator = sortedIterator;
+ this.numberOfRecordsToWritten = numberOfRecordsToWritten;
+ this.writeMetrics = writeMetrics;
+ this.taskMetrics = taskMetrics;
+ }
+
+ protected MemoryBlock allocatePMemPage() throws IOException {
+ return allocatePMemPage(DEFAULT_PAGE_SIZE);
+ }
+
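+ /**
+ * Allocates a PMem-backed page of the given size through the task memory manager and tracks it
+ * so that freeAllPMemPages() can release it later.
+ */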
+ protected MemoryBlock allocatePMemPage(long size) {
+ MemoryBlock page = taskMemoryManager.allocatePage(size, externalSorter, true);
+ if (page != null) {
+ allocatedPMemPages.add(page);
+ }
+ return page;
+ }
+
+ protected void freeAllPMemPages() {
+ for (MemoryBlock page : allocatedPMemPages) {
+ taskMemoryManager.freePage(page, externalSorter);
+ }
+ allocatedPMemPages.clear();
+ }
+
+ public abstract UnsafeSorterIterator getSpillReader() throws IOException;
+}
diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index 32c2d79a..25ca11fb 100644
--- a/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -19,15 +19,13 @@
import com.google.common.io.ByteStreams;
import com.google.common.io.Closeables;
-import com.intel.oap.common.storage.stream.ChunkInputStream;
-import com.intel.oap.common.storage.stream.DataStore;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
+import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.internal.config.package$;
import org.apache.spark.internal.config.ConfigEntry;
import org.apache.spark.io.NioBufferedFileInputStream;
import org.apache.spark.io.ReadAheadInputStream;
-import org.apache.spark.memory.PMemManagerInitializer;
import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockId;
import org.apache.spark.unsafe.Platform;
@@ -38,26 +36,34 @@
* Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
* of the file format).
*/
-public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable {
+public class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable {
public static final int MAX_BUFFER_SIZE_BYTES = 16777216; // 16 mb
- private InputStream in;
- private DataInputStream din;
+ protected InputStream in;
+ protected DataInputStream din;
// Variables that change with every record read:
- private int recordLength;
- private long keyPrefix;
- private int numRecords;
- private int numRecordsRemaining;
-
- private byte[] arr = new byte[1024 * 1024];
- private Object baseObject = arr;
- private final TaskContext taskContext = TaskContext.get();
+ protected int recordLength;
+ protected long keyPrefix;
+ protected int numRecords;
+ protected int numRecordsRemaining;
+
+ protected byte[] arr = new byte[1024 * 1024];
+ protected Object baseObject = arr;
+ protected final TaskContext taskContext = TaskContext.get();
+ protected final TaskMetrics taskMetrics;
+
+ public UnsafeSorterSpillReader(TaskMetrics taskMetrics) {
+ this.taskMetrics = taskMetrics;
+ }
public UnsafeSorterSpillReader(
SerializerManager serializerManager,
+ TaskMetrics taskMetrics,
File file,
BlockId blockId) throws IOException {
+ assert (file.length() > 0);
+ this.taskMetrics = taskMetrics;
final ConfigEntry