diff --git a/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/src/main/java/org/apache/spark/memory/MemoryConsumer.java new file mode 100644 index 00000000..12e360ac --- /dev/null +++ b/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory; + +import java.io.IOException; + +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; + +/** + * A memory consumer of {@link TaskMemoryManager} that supports spilling. + * + * Note: this only supports allocation / spilling of Tungsten memory. + */ +public abstract class MemoryConsumer { + + protected final TaskMemoryManager taskMemoryManager; + private final long pageSize; + private final MemoryMode mode; + protected long used; + + protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize, MemoryMode mode) { + this.taskMemoryManager = taskMemoryManager; + this.pageSize = pageSize; + this.mode = mode; + } + + protected MemoryConsumer(TaskMemoryManager taskMemoryManager) { + this(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP); + } + + /** + * Returns the memory mode, {@link MemoryMode#ON_HEAP} or {@link MemoryMode#OFF_HEAP}. + */ + public MemoryMode getMode() { + return mode; + } + + /** + * Returns the size of used memory in bytes. + */ + public long getUsed() { + return used; + } + + /** + * Force spill during building. + */ + public void spill() throws IOException { + spill(Long.MAX_VALUE, this); + } + + /** + * Spill some data to disk to release memory, which will be called by TaskMemoryManager + * when there is not enough memory for the task. + * + * This should be implemented by subclass. + * + * Note: In order to avoid possible deadlock, should not call acquireMemory() from spill(). + * + * Note: today, this only frees Tungsten-managed pages. + * + * @param size the amount of memory should be released + * @param trigger the MemoryConsumer that trigger this spilling + * @return the amount of released memory in bytes + */ + public abstract long spill(long size, MemoryConsumer trigger) throws IOException; + + /** + * Allocates a LongArray of `size`. Note that this method may throw `SparkOutOfMemoryError` + * if Spark doesn't have enough memory for this allocation, or throw `TooLargePageException` + * if this `LongArray` is too large to fit in a single page. The caller side should take care of + * these two exceptions, or make sure the `size` is small enough that won't trigger exceptions. 
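+ *
+ * <p>A minimal usage sketch from inside a {@link MemoryConsumer} subclass (sizes and values are
+ * illustrative):
+ * <pre>{@code
+ * LongArray buffer = allocateArray(1024);   // reserves 1024 * 8 bytes from the task's pool
+ * try {
+ *   buffer.set(0, 42L);
+ * } finally {
+ *   freeArray(buffer);                      // returns the memory to the TaskMemoryManager
+ * }
+ * }</pre>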
+ * + * @throws SparkOutOfMemoryError + * @throws TooLargePageException + */ + public LongArray allocateArray(long size) { + long required = size * 8L; + MemoryBlock page = taskMemoryManager.allocatePage(required, this); + if (page == null || page.size() < required) { + throwOom(page, required); + } + used += required; + return new LongArray(page); + } + + /** + * Frees a LongArray. + */ + public void freeArray(LongArray array) { + freePage(array.memoryBlock()); + } + + /** + * Allocate a memory block with at least `required` bytes. + * + * @throws SparkOutOfMemoryError + */ + protected MemoryBlock allocatePage(long required) { + MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this); + if (page == null || page.size() < required) { + throwOom(page, required); + } + used += page.size(); + return page; + } + + /** + * Free a memory block. + */ + protected void freePage(MemoryBlock page) { + used -= page.size(); + taskMemoryManager.freePage(page, this); + } + + /** + * Allocates memory of `size`. + */ + public long acquireMemory(long size) { + long granted = taskMemoryManager.acquireExecutionMemory(size, this); + used += granted; + return granted; + } + + /** + * Release N bytes of memory. + */ + public void freeMemory(long size) { + taskMemoryManager.releaseExecutionMemory(size, this); + used -= size; + } + + private void throwOom(final MemoryBlock page, final long required) { + long got = 0; + if (page != null) { + got = page.size(); + taskMemoryManager.freePage(page, this); + } + taskMemoryManager.showMemoryUsage(); + // checkstyle.off: RegexpSinglelineJava + throw new SparkOutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + + got); + // checkstyle.on: RegexpSinglelineJava + } + + public TaskMemoryManager getTaskMemoryManager() { + return taskMemoryManager; + } +} diff --git a/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/src/main/java/org/apache/spark/memory/TaskMemoryManager.java new file mode 100644 index 00000000..6bbe33e0 --- /dev/null +++ b/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -0,0 +1,619 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.memory; + +import javax.annotation.concurrent.GuardedBy; +import java.io.IOException; +import java.nio.channels.ClosedByInterruptException; +import java.util.Arrays; +import java.util.ArrayList; +import java.util.BitSet; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.util.Utils; + +/** + * Manages the memory allocated by an individual task. + *

+ * Most of the complexity in this class deals with encoding of off-heap addresses into 64-bit longs. + * In off-heap mode, memory can be directly addressed with 64-bit longs. In on-heap mode, memory is + * addressed by the combination of a base Object reference and a 64-bit offset within that object. + * This is a problem when we want to store pointers to data structures inside of other structures, + * such as record pointers inside hashmaps or sorting buffers. Even if we decided to use 128 bits + * to address memory, we can't just store the address of the base object since it's not guaranteed + * to remain stable as the heap gets reorganized due to GC. + *

+ * Instead, we use the following approach to encode record pointers in 64-bit longs: for off-heap + * mode, just store the raw address, and for on-heap mode use the upper 13 bits of the address to + * store a "page number" and the lower 51 bits to store an offset within this page. These page + * numbers are used to index into a "page table" array inside of the MemoryManager in order to + * retrieve the base object. + *
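+ * For example (an illustrative walk-through of the on-heap encoding, using the constants defined
+ * below): with 13 page-number bits and 51 offset bits, page number 5 at offset 64 within that page
+ * is encoded as {@code (5L << 51) | 64}; decoding shifts right by 51 bits to recover the page
+ * number and masks the lower 51 bits to recover the offset.
+ *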

+ * This allows us to address 8192 pages. In on-heap mode, the maximum page size is limited by the
+ * maximum size of a long[] array, allowing us to address 8192 * (2^31 - 1) * 8 bytes, which is
+ * approximately 140 terabytes of memory.
+ */
+public class TaskMemoryManager {
+
+ private static final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class);
+
+ /** The number of bits used to address the page table. */
+ private static final int PAGE_NUMBER_BITS = 13;
+
+ /** The number of bits used to encode offsets in data pages. */
+ @VisibleForTesting
+ static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS; // 51
+
+ /** The number of entries in the page table. */
+ private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
+
+ /**
+ * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is
+ * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's
+ * maximum page size is limited by the maximum amount of data that can be stored in a long[]
+ * array, which is (2^31 - 1) * 8 bytes (or about 17 gigabytes). Therefore, we cap this at 17
+ * gigabytes.
+ */
+ public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L;
+
+ /** Bit mask for the lower 51 bits of a long. */
+ private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
+
+ /**
+ * Similar to an operating system's page table, this array maps page numbers into base object
+ * pointers, allowing us to translate between the hashtable's internal 64-bit address
+ * representation and the baseObject+offset representation which we use to support both on- and
+ * off-heap addresses. When using an off-heap allocator, every entry in this map will be `null`.
+ * When using an on-heap allocator, the entries in this map will point to pages' base objects.
+ * Entries are added to this map as new data pages are allocated.
+ */
+ private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE];
+
+ /**
+ * Bitmap for tracking free pages.
+ */
+ private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE);
+
+ private final MemoryManager memoryManager;
+
+ private final long taskAttemptId;
+
+ /**
+ * Tracks whether we're on-heap or off-heap. For off-heap, we short-circuit most of these methods
+ * without doing any masking or lookups. Since this branching should be well-predicted by the JIT,
+ * this extra layer of indirection / abstraction hopefully shouldn't be too expensive.
+ */
+ final MemoryMode tungstenMemoryMode;
+
+ /**
+ * Tracks spillable memory consumers.
+ */
+ @GuardedBy("this")
+ private final HashSet<MemoryConsumer> consumers;
+
+ /**
+ * The amount of memory that is acquired but not used.
+ */
+ private volatile long acquiredButNotUsed = 0L;
+
+ /**
+ * Construct a new TaskMemoryManager.
+ */
+ public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) {
+ this.tungstenMemoryMode = memoryManager.tungstenMemoryMode();
+ this.memoryManager = memoryManager;
+ this.taskAttemptId = taskAttemptId;
+ this.consumers = new HashSet<>();
+ }
+
+ /**
+ * Acquire N bytes of memory for a consumer. If there is not enough memory, this will call
+ * spill() on other consumers to release more memory.
+ *
+ * @return number of bytes successfully granted (<= N).
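+ * A result smaller than N is possible; by the time this method returns, the manager has already
+ * tried spilling other consumers of the same memory mode and, if still short, the requesting
+ * consumer itself. An illustrative call from a consumer:
+ * {@code long got = taskMemoryManager.acquireExecutionMemory(1024 * 1024, this);}.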
+ */
+ public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
+ assert(required >= 0);
+ assert(consumer != null);
+ MemoryMode mode = consumer.getMode();
+ // If we are allocating Tungsten pages off-heap and receive a request to allocate on-heap
+ // memory here, then it may not make sense to spill since that would only end up freeing
+ // off-heap memory. This is subject to change, though, so it may be risky to make this
+ // optimization now in case we forget to undo it later when making changes.
+ synchronized (this) {
+ long got = memoryManager.acquireExecutionMemory(required, taskAttemptId, mode);
+
+ // Try to release memory from other consumers first, so that we can reduce the frequency of
+ // spilling and avoid having too many spill files.
+ if (got < required) {
+ // Call spill() on other consumers to release memory.
+ // Sort the consumers according to their memory usage, so that we avoid re-spilling a
+ // consumer that was spilled in the last few rounds; re-spilling it would produce many
+ // small spill files.
+ TreeMap<Long, List<MemoryConsumer>> sortedConsumers = new TreeMap<>();
+ for (MemoryConsumer c: consumers) {
+ if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) {
+ long key = c.getUsed();
+ List<MemoryConsumer> list =
+ sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1));
+ list.add(c);
+ }
+ }
+ while (!sortedConsumers.isEmpty()) {
+ // Get the consumer using the least memory that is still more than the remaining required memory.
+ Map.Entry<Long, List<MemoryConsumer>> currentEntry =
+ sortedConsumers.ceilingEntry(required - got);
+ // No consumer has used more memory than the remaining required memory.
+ // Fall back to the consumer with the largest memory usage.
+ if (currentEntry == null) {
+ currentEntry = sortedConsumers.lastEntry();
+ }
+ List<MemoryConsumer> cList = currentEntry.getValue();
+ MemoryConsumer c = cList.get(cList.size() - 1);
+ try {
+ long released = c.spill(required - got, consumer);
+ if (released > 0) {
+ logger.debug("Task {} released {} from {} for {}", taskAttemptId,
+ Utils.bytesToString(released), c, consumer);
+ got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode);
+ if (got >= required) {
+ break;
+ }
+ } else {
+ cList.remove(cList.size() - 1);
+ if (cList.isEmpty()) {
+ sortedConsumers.remove(currentEntry.getKey());
+ }
+ }
+ } catch (ClosedByInterruptException e) {
+ // This is thrown when the user kills the task (e.g. a speculative task).
+ logger.error("error while calling spill() on " + c, e);
+ throw new RuntimeException(e.getMessage());
+ } catch (IOException e) {
+ logger.error("error while calling spill() on " + c, e);
+ // checkstyle.off: RegexpSinglelineJava
+ throw new SparkOutOfMemoryError("error while calling spill() on " + c + " : "
+ + e.getMessage());
+ // checkstyle.on: RegexpSinglelineJava
+ }
+ }
+ }
+
+ // Call spill() on the requesting consumer itself.
+ if (got < required) {
+ try {
+ long released = consumer.spill(required - got, consumer);
+ if (released > 0) {
+ logger.debug("Task {} released {} from itself ({})", taskAttemptId,
+ Utils.bytesToString(released), consumer);
+ got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode);
+ }
+ } catch (ClosedByInterruptException e) {
+ // This is thrown when the user kills the task (e.g. a speculative task).
+ logger.error("error while calling spill() on " + consumer, e);
+ throw new RuntimeException(e.getMessage());
+ } catch (IOException e) {
+ logger.error("error while calling spill() on " + consumer, e);
+ // checkstyle.off: RegexpSinglelineJava
+ throw new SparkOutOfMemoryError("error while calling spill() on " + consumer + " : "
+ + e.getMessage());
+ // checkstyle.on: RegexpSinglelineJava
+ }
+ }
+
+ consumers.add(consumer);
+ logger.debug("Task {} acquired {} for {}", taskAttemptId, Utils.bytesToString(got), consumer);
+ return got;
+ }
+ }
+
+ /**
+ * Acquire extended (PMem) memory. When extended memory is acquired, spill will not be triggered.
+ * @param required the number of bytes to acquire
+ * @return the number of bytes actually granted
+ */
+ public long acquireExtendedMemory(long required, MemoryConsumer consumer) {
+ assert(required >= 0);
+ logger.debug("Task {} acquires {} bytes of PMem memory.", taskAttemptId, Utils.bytesToString(required));
+ synchronized (this) {
+ long got = memoryManager.acquireExtendedMemory(required, taskAttemptId);
+ logger.debug("Task {} got {} bytes of PMem memory.", taskAttemptId, Utils.bytesToString(got));
+ // The MemoryConsumer which acquired extended memory should be tracked in the TaskMemoryManager.
+ // It is not yet clear whether it should be added to the consumers set here; a separate list
+ // for consumers which use extended memory may be preferable.
+ consumers.add(consumer);
+ return got;
+ }
+ }
+
+ /**
+ * Release N bytes of execution memory for a MemoryConsumer.
+ */
+ public void releaseExecutionMemory(long size, MemoryConsumer consumer) {
+ logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer);
+ memoryManager.releaseExecutionMemory(size, taskAttemptId, consumer.getMode());
+ }
+
+ public long acquireExtendedMemory(long required) {
+ assert(required >= 0);
+ logger.debug("Task {} acquires {} bytes of PMem memory.", taskAttemptId, Utils.bytesToString(required));
+ synchronized (this) {
+ long got = memoryManager.acquireExtendedMemory(required, taskAttemptId);
+ return got;
+ }
+ }
+
+ public void releaseExtendedMemory(long size) {
+ logger.debug("Task {} releases {} of PMem space.", taskAttemptId, Utils.bytesToString(size));
+ memoryManager.releaseExtendedMemory(size, taskAttemptId);
+ }
+
+ /**
+ * Release extended (PMem) memory.
+ * @param size the number of bytes to release
+ */
+ public void releaseExtendedMemory(long size, MemoryConsumer consumer) {
+ logger.debug("Task {} releases {} of PMem space.", taskAttemptId, Utils.bytesToString(size));
+ memoryManager.releaseExtendedMemory(size, taskAttemptId);
+ }
+
+ /**
+ * Dump the memory usage of all consumers.
+ */
+ public void showMemoryUsage() {
+ logger.info("Memory used in task " + taskAttemptId);
+ synchronized (this) {
+ long memoryAccountedForByConsumers = 0;
+ for (MemoryConsumer c: consumers) {
+ long totalMemUsage = c.getUsed();
+ memoryAccountedForByConsumers += totalMemUsage;
+ if (totalMemUsage > 0) {
+ logger.info("Acquired by " + c + ": " + Utils.bytesToString(totalMemUsage));
+ }
+ }
+ long memoryNotAccountedFor =
+ memoryManager.getExecutionMemoryUsageForTask(taskAttemptId) - memoryAccountedForByConsumers;
+ logger.info(
+ "{} bytes of memory were used by task {} but are not associated with specific consumers",
+ memoryNotAccountedFor, taskAttemptId);
+ logger.info(
+ "{} bytes of memory are used for execution and {} bytes of memory are used for storage",
+ memoryManager.executionMemoryUsed(), memoryManager.storageMemoryUsed());
+ }
+ }
+
+ /**
+ * Returns the page size in bytes.
+ */ + public long pageSizeBytes() { + return memoryManager.pageSizeBytes(); + } + + /** + * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is + * intended for allocating large blocks of Tungsten memory that will be shared between operators. + * + * Returns `null` if there was not enough memory to allocate the page. May return a page that + * contains fewer bytes than requested, so callers should verify the size of returned pages. + * + * @throws TooLargePageException + */ + public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { + return allocatePage(size, consumer, false); + } + + public MemoryBlock allocatePage(long size, MemoryConsumer consumer, boolean useExtendedMem) { + assert(consumer != null); + assert(consumer.getMode() == tungstenMemoryMode); + if (size > MAXIMUM_PAGE_SIZE_BYTES) { + throw new TooLargePageException(size); + } + + long acquired = 0L; + if (useExtendedMem) { + acquired = acquireExtendedMemory(size, consumer); + } else { + acquired = acquireExecutionMemory(size, consumer); + } + if (acquired <= 0) { + return null; + } + + final int pageNumber; + synchronized (this) { + pageNumber = allocatedPages.nextClearBit(0); + if (pageNumber >= PAGE_TABLE_SIZE) { + if (useExtendedMem) { + releaseExtendedMemory(acquired, consumer); + } else { + releaseExecutionMemory(acquired, consumer); + } + throw new IllegalStateException( + "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); + } + allocatedPages.set(pageNumber); + } + MemoryBlock page = null; + try { + if (useExtendedMem) { + page = memoryManager.extendedMemoryAllocator().allocate(size); + page.isExtendedMemory(true); + } else { + page = memoryManager.tungstenMemoryAllocator().allocate(acquired); + } + } catch (OutOfMemoryError e) { + logger.warn("Failed to allocate a page ({} bytes), try again.", acquired); + // there is no enough memory actually, it means the actual free memory is smaller than + // MemoryManager thought, we should keep the acquired memory. + synchronized (this) { + acquiredButNotUsed += acquired; + allocatedPages.clear(pageNumber); + } + if (useExtendedMem) { + // will not force spill when use extended mem + return null; + } else { + // this could trigger spilling to free some pages. + return allocatePage(size, consumer); + } + } + + page.pageNumber = pageNumber; + pageTable[pageNumber] = page; + if (logger.isTraceEnabled()) { + logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired); + } + return page; + } + + /** + * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. + */ + public void freePage(MemoryBlock page, MemoryConsumer consumer) { + assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) : + "Called freePage() on memory that wasn't allocated with allocatePage()"; + assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "Called freePage() on a memory block that has already been freed"; + assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : + "Called freePage() on a memory block that has already been freed"; + assert(allocatedPages.get(page.pageNumber)); + pageTable[page.pageNumber] = null; + synchronized (this) { + allocatedPages.clear(page.pageNumber); + } + if (logger.isTraceEnabled()) { + logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + } + long pageSize = page.size(); + // Clear the page number before passing the block to the MemoryAllocator's free(). 
+ // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed + // page has been inappropriately directly freed without calling TMM.freePage(). + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; + if (page.isExtendedMemory) { + memoryManager.extendedMemoryAllocator().free(page); + releaseExtendedMemory(pageSize, consumer); + } else { + memoryManager.tungstenMemoryAllocator().free(page); + releaseExecutionMemory(pageSize, consumer); + } + } + + /** + * @param size + * @return + */ + public MemoryBlock allocatePMemPage(long size) { + if (size > MAXIMUM_PAGE_SIZE_BYTES) { + throw new TooLargePageException(size); + } + final int pageNumber; + synchronized (this) { + pageNumber = allocatedPages.nextClearBit(0); + if (pageNumber >= PAGE_TABLE_SIZE) { + throw new IllegalStateException( + "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); + } + allocatedPages.set(pageNumber); + } + MemoryBlock page = null; + try { + page = memoryManager.getUsablePMemPage(size); + if (page == null) { + page = memoryManager.extendedMemoryAllocator().allocate(size); + memoryManager.addPMemPages(page); + } else { + logger.debug("reuse pmem page."); + } + } catch (OutOfMemoryError e) { + logger.debug("Failed to allocate a PMem page ({} bytes).", size); + return null; + } + page.isExtendedMemory(true); + page.pageNumber = pageNumber; + pageTable[pageNumber] = page; + if (logger.isTraceEnabled()) { + logger.trace("Allocate page number {} ({} bytes)", pageNumber, size); + } + return page; + + } + + public void freePMemPage(MemoryBlock page, MemoryConsumer consumer) { + assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) : + "Called freePage() on memory that wasn't allocated with allocatePage()"; + assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + "Called freePage() on a memory block that has already been freed"; + assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) : + "Called freePage() on a memory block that has already been freed"; + assert(allocatedPages.get(page.pageNumber)); + pageTable[page.pageNumber] = null; + synchronized (this) { + allocatedPages.clear(page.pageNumber); + } + if (logger.isTraceEnabled()) { + logger.trace("Freed PMem page number {} ({} bytes)", page.pageNumber, page.size()); + } + // Clear the page number before passing the block to the MemoryAllocator's free(). + // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed + // page has been inappropriately directly freed without calling TMM.freePage(). + + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; + // not really free the PMem page for future page reuse + } + + /** + * Given a memory page and offset within that page, encode this address into a 64-bit long. + * This address will remain valid as long as the corresponding page has not been freed. + * + * @param page a data page allocated by {@link TaskMemoryManager#allocatePage}/ + * @param offsetInPage an offset in this page which incorporates the base offset. In other words, + * this should be the value that you would pass as the base offset into an + * UNSAFE call (e.g. page.baseOffset() + something). + * @return an encoded page address. + */ + public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { + if (tungstenMemoryMode == MemoryMode.OFF_HEAP) { + // In off-heap mode, an offset is an absolute address that may require a full 64 bits to + // encode. 
Due to our page size limitation, though, we can convert this into an offset that's + // relative to the page's base offset; this relative offset will fit in 51 bits. + offsetInPage -= page.getBaseOffset(); + } + return encodePageNumberAndOffset(page.pageNumber, offsetInPage); + } + + @VisibleForTesting + public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { + assert (pageNumber >= 0) : "encodePageNumberAndOffset called with invalid page"; + return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS); + } + + @VisibleForTesting + public static int decodePageNumber(long pagePlusOffsetAddress) { + return (int) (pagePlusOffsetAddress >>> OFFSET_BITS); + } + + private static long decodeOffset(long pagePlusOffsetAddress) { + return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); + } + + /** + * Get the page associated with an address encoded by + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} + */ + public Object getPage(long pagePlusOffsetAddress) { + if (tungstenMemoryMode == MemoryMode.ON_HEAP) { + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + final MemoryBlock page = pageTable[pageNumber]; + assert (page != null); + assert (page.getBaseObject() != null); + return page.getBaseObject(); + } else { + return null; + } + } + + public MemoryBlock getOriginalPage(long pagePlusOffsetAddress) { + if (tungstenMemoryMode == MemoryMode.ON_HEAP) { + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + final MemoryBlock page = pageTable[pageNumber]; + assert (page != null); + assert (page.getBaseObject() != null); + return page; + } else { + return null; + } + } + + /** + * Get the offset associated with an address encoded by + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} + */ + public long getOffsetInPage(long pagePlusOffsetAddress) { + final long offsetInPage = decodeOffset(pagePlusOffsetAddress); + if (tungstenMemoryMode == MemoryMode.ON_HEAP) { + return offsetInPage; + } else { + // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we + // converted the absolute address into a relative address. Here, we invert that operation: + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + final MemoryBlock page = pageTable[pageNumber]; + assert (page != null); + return page.getBaseOffset() + offsetInPage; + } + } + + /** + * Clean up all allocated memory and pages. Returns the number of bytes freed. A non-zero return + * value can be used to detect memory leaks. 
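+ *
+ * <p>A minimal sketch of the intended call site (names are illustrative; in Spark the task runner
+ * performs this cleanup after the task body completes):
+ * <pre>{@code
+ * long leakedBytes = taskMemoryManager.cleanUpAllAllocatedMemory();
+ * if (leakedBytes > 0) {
+ *   logger.warn("Task leaked " + leakedBytes + " bytes of managed memory");
+ * }
+ * }</pre>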
+ */ + public long cleanUpAllAllocatedMemory() { + synchronized (this) { + for (MemoryConsumer c: consumers) { + if (c != null && c.getUsed() > 0) { + // In case of failed task, it's normal to see leaked memory + logger.debug("unreleased " + Utils.bytesToString(c.getUsed()) + " memory from " + c); + } + } + consumers.clear(); + + for (MemoryBlock page : pageTable) { + if (page != null) { + logger.debug("unreleased page: " + page + " in task " + taskAttemptId); + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; + if (!page.isExtendedMemory){ + memoryManager.tungstenMemoryAllocator().free(page); + } else { + memoryManager.extendedMemoryAllocator().free(page); + } + } + } + Arrays.fill(pageTable, null); + } + + // release the memory that is not used by any consumer (acquired for pages in tungsten mode). + memoryManager.releaseExecutionMemory(acquiredButNotUsed, taskAttemptId, tungstenMemoryMode); + memoryManager.releaseAllExtendedMemoryForTask(taskAttemptId); + + return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId); + } + + /** + * Returns the memory consumption, in bytes, for the current task. + */ + public long getMemoryConsumptionForThisTask() { + return memoryManager.getExecutionMemoryUsageForTask(taskAttemptId); + } + + /** + * Returns Tungsten memory mode + */ + public MemoryMode getTungstenMemoryMode() { + return tungstenMemoryMode; + } +} diff --git a/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index d4cc153e..e7463908 100644 --- a/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -25,10 +25,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.io.Closeables; -import com.intel.oap.common.storage.stream.ChunkInputStream; -import com.intel.oap.common.storage.stream.DataStore; -import org.apache.spark.internal.config.package$; -import org.apache.spark.memory.PMemManagerInitializer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -286,7 +282,7 @@ private void advanceToNextPage() { } try { Closeables.close(reader, /* swallowIOException = */ false); - reader = spillWriters.getFirst().getReader(serializerManager); + reader = spillWriters.getFirst().getReader(serializerManager, null); recordsInPage = -1; } catch (IOException e) { // Scala iterator does not handle exception @@ -397,23 +393,10 @@ public void remove() { } private void handleFailedDelete() { - // remove the spill file from disk or pmem - boolean pMemSpillEnabled = SparkEnv.get() != null && (boolean) SparkEnv.get().conf().get( - package$.MODULE$.MEMORY_SPILL_PMEM_ENABLED()); + // remove the spill file from disk File file = spillWriters.removeFirst().getFile(); - if(pMemSpillEnabled == true) { - try { - ChunkInputStream cis = ChunkInputStream.getChunkInputStreamInstance(file.toString(), - new DataStore(PMemManagerInitializer.getPMemManager(), - PMemManagerInitializer.getProperties())); - cis.free(); - } catch (IOException e) { - logger.debug(e.toString()); - } - } else { - if (file != null && file.exists() && !file.delete()) { - logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); - } + if (file != null && file.exists() && !file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); } } } @@ -835,22 +818,6 @@ public void free() { } assert(dataPages.isEmpty()); - deleteSpillFiles(); - } - - /** - * Deletes any spill files created by this 
consumer. - */ - private void deleteSpillFiles() { - boolean pMemSpillEnabled = SparkEnv.get() != null && (boolean) SparkEnv.get().conf().get( - package$.MODULE$.MEMORY_SPILL_PMEM_ENABLED()); - if (pMemSpillEnabled == true) - deletePMemSpillFiles(); - else - deleteDiskSpillFiles(); - } - - private void deleteDiskSpillFiles() { while (!spillWriters.isEmpty()) { File file = spillWriters.removeFirst().getFile(); if (file != null && file.exists()) { @@ -861,20 +828,6 @@ private void deleteDiskSpillFiles() { } } - private void deletePMemSpillFiles() { - while (!spillWriters.isEmpty()) { - File file = spillWriters.removeFirst().getFile(); - try { - ChunkInputStream cis = ChunkInputStream.getChunkInputStreamInstance(file.toString(), - new DataStore(PMemManagerInitializer.getPMemManager(), - PMemManagerInitializer.getProperties())); - cis.free(); - } catch (IOException e) { - logger.debug(e.toString()); - } - } - } - public TaskMemoryManager getTaskMemoryManager() { return taskMemoryManager; } diff --git a/src/main/java/org/apache/spark/unsafe/memory/ExtendedMemoryAllocator.java b/src/main/java/org/apache/spark/unsafe/memory/ExtendedMemoryAllocator.java new file mode 100644 index 00000000..7223068a --- /dev/null +++ b/src/main/java/org/apache/spark/unsafe/memory/ExtendedMemoryAllocator.java @@ -0,0 +1,21 @@ + +package org.apache.spark.unsafe.memory; +import com.intel.oap.common.unsafe.PersistentMemoryPlatform; + +public class ExtendedMemoryAllocator implements MemoryAllocator{ + + @Override + public MemoryBlock allocate(long size) throws OutOfMemoryError { + long address = PersistentMemoryPlatform.allocateVolatileMemory(size); + MemoryBlock memoryBlock = new MemoryBlock(null, address, size); + + return memoryBlock; + } + + @Override + public void free(MemoryBlock memoryBlock) { + assert (memoryBlock.getBaseObject() == null) : + "baseObject not null; are you trying to use the AEP-heap allocator to free on-heap memory?"; + PersistentMemoryPlatform.freeMemory(memoryBlock.getBaseOffset()); + } +} \ No newline at end of file diff --git a/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java new file mode 100644 index 00000000..b47370f4 --- /dev/null +++ b/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +public interface MemoryAllocator { + + /** + * Whether to fill newly allocated and deallocated memory with 0xa5 and 0x5a bytes respectively. + * This helps catch misuse of uninitialized or freed memory, but imposes some overhead. 
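+ * For example, it can be enabled on the JVM command line (illustrative):
+ * {@code -Dspark.memory.debugFill=true}.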
+ */ + boolean MEMORY_DEBUG_FILL_ENABLED = Boolean.parseBoolean( + System.getProperty("spark.memory.debugFill", "false")); + + // Same as jemalloc's debug fill values. + byte MEMORY_DEBUG_FILL_CLEAN_VALUE = (byte)0xa5; + byte MEMORY_DEBUG_FILL_FREED_VALUE = (byte)0x5a; + + /** + * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed + * to be zeroed out (call `fill(0)` on the result if this is necessary). + */ + MemoryBlock allocate(long size) throws OutOfMemoryError; + + void free(MemoryBlock memory); + + MemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); + + MemoryAllocator HEAP = new HeapMemoryAllocator(); + + MemoryAllocator EXTENDED = new ExtendedMemoryAllocator(); +} diff --git a/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java new file mode 100644 index 00000000..7517b3a0 --- /dev/null +++ b/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import javax.annotation.Nullable; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. + */ +public class MemoryBlock extends MemoryLocation { + + /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */ + public static final int NO_PAGE_NUMBER = -1; + + /** + * Special `pageNumber` value for marking pages that have been freed in the TaskMemoryManager. + * We set `pageNumber` to this value in TaskMemoryManager.freePage() so that MemoryAllocator + * can detect if pages which were allocated by TaskMemoryManager have been freed in the TMM + * before being passed to MemoryAllocator.free() (it is an error to allocate a page in + * TaskMemoryManager and then directly free it in a MemoryAllocator without going through + * the TMM freePage() call). + */ + public static final int FREED_IN_TMM_PAGE_NUMBER = -2; + + /** + * Special `pageNumber` value for pages that have been freed by the MemoryAllocator. This allows + * us to detect double-frees. + */ + public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3; + + private final long length; + + /** + * Optional page number; used when this MemoryBlock represents a page allocated by a + * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, + * which lives in a different package. + */ + public int pageNumber = NO_PAGE_NUMBER; + + /** + * Indicate the memory block is on extended memory or not. 
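+ * For example, pages obtained through TaskMemoryManager#allocatePMemPage, or through allocatePage
+ * with useExtendedMem set, carry this flag so that freeing them goes through the extended (PMem)
+ * allocator rather than the Tungsten allocator.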
+ */ + public boolean isExtendedMemory = false; + + public MemoryBlock(@Nullable Object obj, long offset, long length) { + super(obj, offset); + this.length = length; + } + + /** + * Returns the size of the memory block. + */ + public long size() { + return length; + } + + /** + * Creates a memory block pointing to the memory used by the long array. + */ + public static MemoryBlock fromLongArray(final long[] array) { + return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); + } + + /** + * Fills the memory block with the specified byte value. + */ + public void fill(byte value) { + Platform.setMemory(obj, offset, length, value); + } + + /** + * Whether this memory block is on extended memory. + * @return + */ + public boolean isExtendedMemory() { + return isExtendedMemory; + } + + /** + * set whether this memory block is on extended memory. + * @param isExtended + */ + public void isExtendedMemory(boolean isExtended) { + this.isExtendedMemory = isExtended; + } +} diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemReader.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemReader.java new file mode 100644 index 00000000..1d2b1ee3 --- /dev/null +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemReader.java @@ -0,0 +1,87 @@ +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; + +import java.io.Closeable; +import java.util.LinkedList; + +public final class PMemReader extends UnsafeSorterIterator implements Closeable { + private int recordLength; + private long keyPrefix; + private int numRecordsRemaining; + private int numRecords; + private LinkedList pMemPages; + private MemoryBlock pMemPage = null; + private int readingPageIndex = 0; + private int readedRecordsInCurrentPage = 0; + private int numRecordsInpage = 0; + private long offset = 0; + private byte[] arr = new byte[1024 * 1024]; + private Object baseObject = arr; + public PMemReader(LinkedList pMemPages, int numRecords) { + this.pMemPages = pMemPages; + this.numRecordsRemaining = this.numRecords = numRecords; + } + @Override + public void loadNext() { + assert (readingPageIndex <= pMemPages.size()) + : "Illegal state: Pages finished read but hasNext() is true."; + if(pMemPage == null || readedRecordsInCurrentPage == numRecordsInpage) { + // read records from each page + pMemPage = pMemPages.get(readingPageIndex++); + readedRecordsInCurrentPage = 0; + numRecordsInpage = Platform.getInt(null, pMemPage.getBaseOffset()); + offset = pMemPage.getBaseOffset() + 4; + } + // record: BaseOffSet, record length, KeyPrefix, record value + keyPrefix = Platform.getLong(null, offset); + offset += 8; + recordLength = Platform.getInt(null, offset); + offset += 4; + if (recordLength > arr.length) { + arr = new byte[recordLength]; + baseObject = arr; + } + Platform.copyMemory(null, offset , baseObject, Platform.BYTE_ARRAY_OFFSET, recordLength); + offset += recordLength; + readedRecordsInCurrentPage ++; + numRecordsRemaining --; + + + } + @Override + public int getNumRecords() { + return numRecords; + } + + @Override + public boolean hasNext() { + return (numRecordsRemaining > 0); + } + + @Override + public Object getBaseObject() { + return baseObject; + } + + @Override + public long getBaseOffset() { + return Platform.BYTE_ARRAY_OFFSET; + } + + @Override + public int getRecordLength() { + return recordLength; + } + + @Override + public long getKeyPrefix() { + return keyPrefix; 
+ } + + @Override + public void close() { + // do nothing here + } +} diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemReaderForUnsafeExternalSorter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemReaderForUnsafeExternalSorter.java new file mode 100644 index 00000000..f970b0dd --- /dev/null +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemReaderForUnsafeExternalSorter.java @@ -0,0 +1,141 @@ +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.SparkEnv; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.internal.config.package$; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Closeable; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +public final class PMemReaderForUnsafeExternalSorter extends UnsafeSorterIterator implements Closeable { + private static final Logger logger = LoggerFactory.getLogger(PMemReaderForUnsafeExternalSorter.class); + private int recordLength; + private long keyPrefix; + private int numRecordsRemaining; + private int numRecords; + private LongArray sortedArray; + private int position; + private byte[] arr = new byte[1024 * 1024]; + private byte[] bytes = new byte[1024 * 1024]; + private Object baseObject = arr; + private TaskMetrics taskMetrics; + private long startTime; + private ByteBuffer byteBuffer; + public PMemReaderForUnsafeExternalSorter( + LongArray sortedArray, int position, int numRecords, TaskMetrics taskMetrics) { + this.sortedArray = sortedArray; + this.position = position; + this.numRecords = numRecords; + this.numRecordsRemaining = numRecords - position/2; + this.taskMetrics = taskMetrics; + int readBufferSize = SparkEnv.get() == null? 8 * 1024 * 1024 : + (int) SparkEnv.get().conf().get(package$.MODULE$.MEMORY_SPILL_PMEM_READ_BUFFERSIZE()); + logger.info("PMem read buffer size is:" + Utils.bytesToString(readBufferSize)); + this.byteBuffer = ByteBuffer.wrap(new byte[readBufferSize]); + byteBuffer.flip(); + byteBuffer.order(ByteOrder.nativeOrder()); + } + + @Override + public void loadNext() { + if (!byteBuffer.hasRemaining()) { + boolean refilled = refill(); + if (!refilled) { + logger.error("Illegal status: records finished read but hasNext() is true."); + } + } + keyPrefix = byteBuffer.getLong(); + recordLength = byteBuffer.getInt(); + if (recordLength > arr.length) { + arr = new byte[recordLength]; + baseObject = arr; + } + byteBuffer.get(arr, 0, recordLength); + numRecordsRemaining --; + } + + @Override + public int getNumRecords() { + return numRecords; + } + + /** + * load more PMem records in the buffer + */ + private boolean refill() { + byteBuffer.clear(); + int nRead = loadData(); + byteBuffer.flip(); + if (nRead <= 0) { + return false; + } + return true; + } + + private int loadData() { + // no records remaining to read + if (position >= numRecords * 2) + return -1; + int bufferPos = 0; + int capacity = byteBuffer.capacity(); + while (bufferPos < capacity && position < numRecords * 2) { + long curRecordAddress = sortedArray.get(position); + int recordLen = Platform.getInt(null, curRecordAddress); + // length + keyprefix + record length + int length = Integer.BYTES + Long.BYTES + recordLen; + if (length > capacity) { + logger.error("single record size exceeds PMem read buffer. 
Please increase buffer size."); + } + if (bufferPos + length <= capacity) { + long curKeyPrefix = sortedArray.get(position + 1); + if (length > bytes.length) { + bytes = new byte[length]; + } + Platform.putLong(bytes, Platform.BYTE_ARRAY_OFFSET, curKeyPrefix); + Platform.copyMemory(null, curRecordAddress, bytes, Platform.BYTE_ARRAY_OFFSET + Long.BYTES, length - Long.BYTES); + byteBuffer.put(bytes, 0, length); + bufferPos += length; + position += 2; + } else { + break; + } + } + return bufferPos; + } + + @Override + public boolean hasNext() { + return (numRecordsRemaining > 0); + } + + @Override + public Object getBaseObject() { + return baseObject; + } + + @Override + public long getBaseOffset() { + return Platform.BYTE_ARRAY_OFFSET; + } + + @Override + public int getRecordLength() { + return recordLength; + } + + @Override + public long getKeyPrefix() { + return keyPrefix; + } + + @Override + public void close() { + // do nothing here + } +} diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemSpillWriterFactory.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemSpillWriterFactory.java new file mode 100644 index 00000000..97d7c61d --- /dev/null +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemSpillWriterFactory.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.serializer.SerializerManager; +import org.apache.spark.storage.BlockManager; + +import java.io.IOException; + +public class PMemSpillWriterFactory { + public static SpillWriterForUnsafeSorter getSpillWriter( + PMemSpillWriterType writerType, + UnsafeExternalSorter externalSorter, + UnsafeSorterIterator sortedIterator, + int numberOfRecordsToWritten, + SerializerManager serializerManager, + BlockManager blockManager, + int fileBufferSize, + ShuffleWriteMetrics writeMetrics, + TaskMetrics taskMetrics, + boolean spillToPMEMEnabled, + boolean isSorted) throws IOException { + if (spillToPMEMEnabled && writerType == PMemSpillWriterType.MEM_COPY_ALL_DATA_PAGES_TO_PMEM_WITHLONGARRAY){ + SortedIteratorForSpills sortedSpillIte = SortedIteratorForSpills.createFromExistingSorterIte( + (UnsafeInMemorySorter.SortedIterator)sortedIterator, + externalSorter.getInMemSorter()); + return new PMemWriter( + externalSorter, + sortedSpillIte, + isSorted, + numberOfRecordsToWritten, + serializerManager, + blockManager, + fileBufferSize, + writeMetrics, + taskMetrics); + } else { + if (sortedIterator == null) { + sortedIterator = externalSorter.getInMemSorter().getSortedIterator(); + } + if (spillToPMEMEnabled && sortedIterator instanceof UnsafeInMemorySorter.SortedIterator){ + + if (writerType == PMemSpillWriterType.WRITE_SORTED_RECORDS_TO_PMEM) { + SortedIteratorForSpills sortedSpillIte = SortedIteratorForSpills.createFromExistingSorterIte( + (UnsafeInMemorySorter.SortedIterator)sortedIterator, + externalSorter.getInMemSorter()); + return new SortedPMemPageSpillWriter( + externalSorter, + sortedSpillIte, + numberOfRecordsToWritten, + serializerManager, + blockManager, + fileBufferSize, + writeMetrics, + taskMetrics); + } + if (spillToPMEMEnabled && writerType == PMemSpillWriterType.STREAM_SPILL_TO_PMEM) { + return new UnsafeSorterStreamSpillWriter( + blockManager, + fileBufferSize, + sortedIterator, + numberOfRecordsToWritten, + serializerManager, + writeMetrics, + taskMetrics); + } + } else { + return new UnsafeSorterSpillWriter( + blockManager, + fileBufferSize, + sortedIterator, + numberOfRecordsToWritten, + serializerManager, + writeMetrics, + taskMetrics); + } + + } + return null; + } +} diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemSpillWriterType.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemSpillWriterType.java new file mode 100644 index 00000000..e6bad61d --- /dev/null +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemSpillWriterType.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +public enum PMemSpillWriterType { + STREAM_SPILL_TO_PMEM, + MEM_COPY_ALL_DATA_PAGES_TO_PMEM, + MEM_COPY_ALL_DATA_PAGES_TO_PMEM_WITHLONGARRAY, + WRITE_SORTED_RECORDS_TO_PMEM +} diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemWriter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemWriter.java new file mode 100644 index 00000000..2db27343 --- /dev/null +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/PMemWriter.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import com.intel.oap.common.unsafe.PersistentMemoryPlatform; + +import org.apache.spark.SparkEnv; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.internal.config.package$; +import org.apache.spark.serializer.SerializerManager; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +/** + * In this writer, records page along with LongArray page are both dumped to PMem when spill happens + */ +public final class PMemWriter extends UnsafeSorterPMemSpillWriter { + private static final Logger logger = LoggerFactory.getLogger(PMemWriter.class); + private LongArray sortedArray; + private HashMap pageMap = new HashMap<>(); + private int position; + private LinkedList allocatedDramPages; + private MemoryBlock pMemPageForLongArray; + private UnsafeSorterSpillWriter diskSpillWriter; + private BlockManager blockManager; + private SerializerManager serializerManager; + private int fileBufferSize; + private boolean isSorted; + private int totalRecordsWritten; + private final boolean spillToPMemConcurrently = SparkEnv.get() != null && (boolean) SparkEnv.get().conf().get( + package$.MODULE$.MEMORY_SPILL_PMEM_SORT_BACKGROUND()); + private final boolean pMemClflushEnabled = SparkEnv.get() != null && + (boolean)SparkEnv.get().conf().get(package$.MODULE$.MEMORY_SPILL_PMEM_CLFLUSH_ENABLED()); + + public PMemWriter( + UnsafeExternalSorter externalSorter, + SortedIteratorForSpills sortedIterator, + boolean isSorted, + int numberOfRecordsToWritten, + SerializerManager serializerManager, + BlockManager 
blockManager, + int fileBufferSize, + ShuffleWriteMetrics writeMetrics, + TaskMetrics taskMetrics) { + // SortedIterator is null or readingIterator from UnsafeExternalSorter. + // But it isn't used in this PMemWriter, only for keep same constructor with other spill writers. + super(externalSorter, sortedIterator, numberOfRecordsToWritten, writeMetrics, taskMetrics); + this.allocatedDramPages = externalSorter.getAllocatedPages(); + this.blockManager = blockManager; + this.serializerManager = serializerManager; + this.fileBufferSize = fileBufferSize; + this.isSorted = isSorted; + // In the case that spill happens when iterator isn't sorted yet, the valid records + // will be [0, inMemsorter.numRecords]. When iterator is sorted, the valid records will be + // [position/2, inMemsorter.numRecords] + this.totalRecordsWritten = externalSorter.getInMemSorter().numRecords(); + } + + @Override + public void write() throws IOException { + // write records based on externalsorter + // try to allocate all needed PMem pages before spill to PMem + UnsafeInMemorySorter inMemSorter = externalSorter.getInMemSorter(); + if (allocatePMemPages(allocatedDramPages, inMemSorter.getArray().memoryBlock())) { + if (spillToPMemConcurrently && !isSorted) { + logger.info("Concurrent PMem write/records sort"); + long writeDuration = 0; + ExecutorService executorService = Executors.newSingleThreadExecutor(); + Future future = executorService.submit(()->dumpPagesToPMem()); + externalSorter.getInMemSorter().getSortedIterator(); + try { + writeDuration = future.get(); + } catch (InterruptedException | ExecutionException e) { + logger.error(e.getMessage()); + } + executorService.shutdownNow(); + updateLongArray(inMemSorter.getArray(), totalRecordsWritten, 0); + } else if(!isSorted) { + dumpPagesToPMem(); + // get sorted iterator + externalSorter.getInMemSorter().getSortedIterator(); + // update LongArray + updateLongArray(inMemSorter.getArray(), totalRecordsWritten, 0); + } else { + dumpPagesToPMem(); + // get sorted iterator + assert(sortedIterator != null); + updateLongArray(inMemSorter.getArray(), totalRecordsWritten, sortedIterator.getPosition()); + } + } else { + // fallback to disk spill + if (diskSpillWriter == null) { + diskSpillWriter = new UnsafeSorterSpillWriter( + blockManager, + fileBufferSize, + sortedIterator, + numberOfRecordsToWritten, + serializerManager, + writeMetrics, + taskMetrics); + } + diskSpillWriter.write(false); + } + } + + public boolean allocatePMemPages(LinkedList dramPages, MemoryBlock longArrayPage) { + for (MemoryBlock page: dramPages) { + MemoryBlock pMemBlock = taskMemoryManager.allocatePMemPage(page.size()); + if (pMemBlock != null) { + allocatedPMemPages.add(pMemBlock); + pageMap.put(page, pMemBlock); + } else { + pageMap.clear(); + return false; + } + } + pMemPageForLongArray = taskMemoryManager.allocatePMemPage(longArrayPage.size()); + if (pMemPageForLongArray != null) { + allocatedPMemPages.add(pMemPageForLongArray); + pageMap.put(longArrayPage, pMemPageForLongArray); + } else { + pageMap.clear(); + return false; + } + return (allocatedPMemPages.size() == dramPages.size() + 1); + } + + private long dumpPagesToPMem() { + long dumpTime = System.nanoTime(); + for (MemoryBlock page : allocatedDramPages) { + dumpPageToPMem(page); + } + long dumpDuration = System.nanoTime() - dumpTime; + return dumpDuration; + + } + + private void dumpPageToPMem(MemoryBlock page) { + MemoryBlock pMemBlock = pageMap.get(page); + PersistentMemoryPlatform.copyMemory( + page.getBaseObject(), 
page.getBaseOffset(), + null, pMemBlock.getBaseOffset(), page.size(), + pMemClflushEnabled); + writeMetrics.incBytesWritten(page.size()); + } + + public void updateLongArray(LongArray sortedArray, int numRecords, int position) { + this.position = position; + while (position < numRecords * 2){ + // update recordPointer in this array + long originalRecordPointer = sortedArray.get(position); + MemoryBlock page = taskMemoryManager.getOriginalPage(originalRecordPointer); + long offset = taskMemoryManager.getOffsetInPage(originalRecordPointer) - page.getBaseOffset(); + MemoryBlock pMemBlock = pageMap.get(page); + long pMemOffset = pMemBlock.getBaseOffset() + offset; + sortedArray.set(position, pMemOffset); + position += 2; + } + // copy the LongArray to PMem + MemoryBlock arrayBlock = sortedArray.memoryBlock(); + MemoryBlock pMemBlock = pageMap.get(arrayBlock); + PersistentMemoryPlatform.copyMemory( + arrayBlock.getBaseObject(), arrayBlock.getBaseOffset(), + null, pMemBlock.getBaseOffset(), arrayBlock.size(), + pMemClflushEnabled); + writeMetrics.incBytesWritten(pMemBlock.size()); + this.sortedArray = new LongArray(pMemBlock); + } + + @Override + public UnsafeSorterIterator getSpillReader() throws IOException { + // TODO: consider partial spill to PMem + Disk. + if (diskSpillWriter != null) { + return diskSpillWriter.getSpillReader(); + } else { + return new PMemReaderForUnsafeExternalSorter(sortedArray, position, totalRecordsWritten, taskMetrics); + } + } + + public void clearAll() { + freeAllPMemPages(); + if (diskSpillWriter != null) { + diskSpillWriter.clearAll(); + } + } + + @Override + public int recordsSpilled() { + return numberOfRecordsToWritten; + } + + @Override + public void freeAllPMemPages() { + for( MemoryBlock page: allocatedPMemPages) { + taskMemoryManager.freePMemPage(page, externalSorter); + } + allocatedPMemPages.clear(); + } +} diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/SortedIteratorForSpills.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/SortedIteratorForSpills.java new file mode 100644 index 00000000..5c2961fd --- /dev/null +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/SortedIteratorForSpills.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.TaskContext; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.array.LongArray; + +public class SortedIteratorForSpills extends UnsafeSorterIterator { + private LongArray sortedArray; + private final int numRecords; + private int position; + private int offset; + private Object baseObject; + private long baseOffset; + private long keyPrefix; + private int recordLength; + private long currentPageNumber; + private final TaskContext taskContext = TaskContext.get(); + private final TaskMemoryManager memoryManager; + + /** + * Construct an iterator to read the spill. + * @param array the array here should be already sorted. + * @param numRecords the number of the records recorded in this LongArray + * @param offset + */ + public SortedIteratorForSpills( + final TaskMemoryManager memoryManager, + LongArray array, + int numRecords, + int offset) { + this.memoryManager = memoryManager; + this.sortedArray = array; + this.numRecords = numRecords; + this.position = 0; + this.offset = offset; + } + + public static SortedIteratorForSpills createFromExistingSorterIte( + UnsafeInMemorySorter.SortedIterator sortedIte, + UnsafeInMemorySorter inMemSorter) { + if (sortedIte == null) { + return null; + } + TaskMemoryManager taskMemoryManager = inMemSorter.getTaskMemoryManager(); + LongArray array = inMemSorter.getLongArray(); + int numberRecords = sortedIte.getNumRecords(); + int offset = sortedIte.getOffset(); + int position = sortedIte.getPosition(); + SortedIteratorForSpills spillIte = new SortedIteratorForSpills(taskMemoryManager, array,numberRecords,offset); + spillIte.pointTo(position); + return spillIte; + } + + @Override + public int getNumRecords() { + return numRecords; + } + + @Override + public boolean hasNext() { + return position / 2 < numRecords; + } + + @Override + public void loadNext() { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. This check is added here in `loadNext()` instead of in + // `hasNext()` because it's technically possible for the caller to be relying on + // `getNumRecords()` instead of `hasNext()` to know when to stop. + if (taskContext != null) { + taskContext.killTaskIfInterrupted(); + } + loadPosition(); + } + /** + * load the record of current position and move to end of the record. + */ + private void loadPosition() { + // This pointer points to a 4-byte record length, followed by the record's bytes + final long recordPointer = sortedArray.get(offset + position); + currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer); + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + baseObject = memoryManager.getPage(recordPointer); + // Skip over record length + baseOffset = memoryManager.getOffsetInPage(recordPointer) + uaoSize; + recordLength = UnsafeAlignedOffset.getSize(baseObject, baseOffset - uaoSize); + keyPrefix = sortedArray.get(offset + position + 1); + position += 2; + } + + /** + * point to a given position. 
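+ * Each record occupies two consecutive slots in the array (pointer, then key prefix),
+ * so the given position must be an even index.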
+ * @param pos + */ + public void pointTo(int pos) { + if (pos % 2 != 0) { + throw new IllegalArgumentException("Can't point to the middle of a record."); + } + position = pos; + } + + @Override + public Object getBaseObject() { return baseObject; } + + @Override + public long getBaseOffset() { return baseOffset; } + + public long getCurrentPageNumber() { return currentPageNumber; } + + @Override + public int getRecordLength() { return recordLength; } + + @Override + public long getKeyPrefix() { return keyPrefix; } + + public LongArray getLongArray() { return sortedArray; } + + public int getPosition() { return position; } +} diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/SortedPMemPageSpillWriter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/SortedPMemPageSpillWriter.java new file mode 100644 index 00000000..bd8ca9c8 --- /dev/null +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/SortedPMemPageSpillWriter.java @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection.unsafe.sort; + +import com.intel.oap.common.unsafe.PersistentMemoryPlatform; + +import org.apache.spark.SparkEnv; +import org.apache.spark.internal.config.package$; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.serializer.SerializerManager; +import java.io.IOException; +import java.util.LinkedHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SortedPMemPageSpillWriter extends UnsafeSorterPMemSpillWriter { + private static final Logger sorted_logger = LoggerFactory.getLogger(SortedPMemPageSpillWriter.class); + private MemoryBlock currentPMemPage = null; + private long currentOffsetInPage = 0L; + private int currentNumOfRecordsInPage = 0; + //Page -> record number map + private LinkedHashMap pageNumOfRecMap = new LinkedHashMap(); + private int numRecords = 0; + private int numRecordsOnPMem = 0; + + private BlockManager blockManager; + private SerializerManager serializerManager; + private int fileBufferSize = 0; + private UnsafeSorterSpillWriter diskSpillWriter; + + private final boolean pMemClflushEnabled = SparkEnv.get() != null && + (boolean)SparkEnv.get().conf().get(package$.MODULE$.MEMORY_SPILL_PMEM_CLFLUSH_ENABLED()); + + public SortedPMemPageSpillWriter( + UnsafeExternalSorter externalSorter, + SortedIteratorForSpills sortedIterator, + int numberOfRecordsToWritten, + SerializerManager serializerManager, + BlockManager blockManager, + int fileBufferSize, + ShuffleWriteMetrics writeMetrics, + TaskMetrics taskMetrics) { + super(externalSorter, sortedIterator, numberOfRecordsToWritten, writeMetrics, taskMetrics); + this.blockManager = blockManager; + this.serializerManager = serializerManager; + this.fileBufferSize = fileBufferSize; + } + + @Override + public void write() throws IOException { + boolean allBeWritten = writeToPMem(); + if (!allBeWritten) { + sorted_logger.debug("No more PMEM space available. Write left spills to disk."); + writeToDisk(); + } + } + + /** + * @return if all records have been write to PMem, return true. Otherwise, return false. + * @throws IOException + */ + private boolean writeToPMem() throws IOException { + while (sortedIterator.hasNext()) { + sortedIterator.loadNext(); + final Object baseObject = sortedIterator.getBaseObject(); + final long baseOffset = sortedIterator.getBaseOffset(); + int curRecLen = sortedIterator.getRecordLength(); + long curPrefix = sortedIterator.getKeyPrefix(); + if (needNewPMemPage(curRecLen)) { + currentPMemPage = allocatePMemPage(); + } + if (currentPMemPage != null) { + long pageBaseOffset = currentPMemPage.getBaseOffset(); + long curPMemOffset = pageBaseOffset + currentOffsetInPage; + writeRecordToPMem(baseObject, baseOffset, curRecLen, curPrefix, curPMemOffset); + currentNumOfRecordsInPage ++; + pageNumOfRecMap.put(currentPMemPage, currentNumOfRecordsInPage); + numRecords ++; + } else { + //No more PMem space available, current loaded record can't be written to PMem. + return false; + } + } + //All records have been written to PMem. 
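+ // (Editorial note, not part of the original patch: at this point every loaded
+ // record has been copied to PMem in the [record length (uaoSize bytes)]
+ // [key prefix (8 bytes)][record bytes] layout produced by writeRecordToPMem below.)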
+ return true; + } + + private void writeToDisk() throws IOException{ + int numOfRecLeft = numberOfRecordsToWritten - numRecordsOnPMem; + if (diskSpillWriter == null) { + diskSpillWriter = new UnsafeSorterSpillWriter( + blockManager, + fileBufferSize, + sortedIterator, + numOfRecLeft, + serializerManager, + writeMetrics, + taskMetrics); + } + diskSpillWriter.write(true); + sorted_logger.info("Num of rec {}; Num of rec written to PMem {}; still {} records left; num of rec written to disk {}.", + sortedIterator.getNumRecords(), + numRecordsOnPMem, + numOfRecLeft, + diskSpillWriter.recordsSpilled()); + } + + private boolean needNewPMemPage(int nextRecLen) { + if (allocatedPMemPages.isEmpty()) { + return true; + } + else { + long pageBaseOffset = currentPMemPage.getBaseOffset(); + long leftLenInCurPage = currentPMemPage.size() - currentOffsetInPage; + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + long recSizeRequired = uaoSize + Long.BYTES + nextRecLen; + if (leftLenInCurPage < recSizeRequired) { + return true; + } + } + return false; + } + + private void writeRecordToPMem(Object baseObject, long baseOffset, int recLength, long prefix, long pMemOffset){ + Platform.putInt( + null, + pMemOffset, + recLength); + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + long currentOffset = pMemOffset + uaoSize; + Platform.putLong( + null, + currentOffset, + prefix); + currentOffset += Long.BYTES; + PersistentMemoryPlatform.copyMemory( + baseObject, + baseOffset, + null, + currentOffset, + recLength, + pMemClflushEnabled); + numRecordsOnPMem ++; + currentOffsetInPage += uaoSize + Long.BYTES + recLength; + } + + protected MemoryBlock allocatePMemPage() throws IOException{ + currentPMemPage = super.allocatePMemPage(); + currentOffsetInPage = 0; + currentNumOfRecordsInPage = 0; + return currentPMemPage; + } + + @Override + public UnsafeSorterIterator getSpillReader() throws IOException { + return new SortedPMemPageSpillReader(); + } + + @Override + public void clearAll() { + freeAllPMemPages(); + if (diskSpillWriter != null) { + diskSpillWriter.clearAll(); + } + } + + public int recordsSpilled() { + int recordsSpilledOnDisk = 0; + if (diskSpillWriter != null) { + recordsSpilledOnDisk = diskSpillWriter.recordsSpilled(); + } + return numRecordsOnPMem + recordsSpilledOnDisk; + } + + private class SortedPMemPageSpillReader extends UnsafeSorterIterator { + private final Logger sorted_reader_logger = LoggerFactory.getLogger(SortedPMemPageSpillReader.class); + private MemoryBlock curPage = null; + private int curPageIdx = -1; + private int curOffsetInPage = 0; + private int curNumOfRecInPage = 0; + private int curNumOfRec = 0; + private Object baseObject = null; + private long baseOffset = 0; + private int recordLength; + private long keyPrefix; + private UnsafeSorterIterator diskSpillReader; + private int numRecordsOnDisk = 0; + + public SortedPMemPageSpillReader() throws IOException{ + if (diskSpillWriter != null) { + diskSpillReader = diskSpillWriter.getSpillReader(); + numRecordsOnDisk = diskSpillWriter.recordsSpilled(); + } + } + @Override + public boolean hasNext() { + return curNumOfRec < numRecordsOnPMem + numRecordsOnDisk; + } + @Override + public void loadNext() throws IOException { + if(curNumOfRec < numRecordsOnPMem) { + loadNextOnPMem(); + } else { + loadNextOnDisk(); + } + } + + private void loadNextOnPMem() throws IOException { + if (curPage == null || curNumOfRecInPage >= pageNumOfRecMap.get(curPage)) { + moveToNextPMemPage(); + } + long curPageBaseOffset = curPage.getBaseOffset(); + 
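// (Reads back the same [record length][key prefix][record bytes] layout that
+ // writeRecordToPMem stored on the PMem page.) +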
recordLength = UnsafeAlignedOffset.getSize(null, curPageBaseOffset + curOffsetInPage); + curOffsetInPage += UnsafeAlignedOffset.getUaoSize(); + keyPrefix = Platform.getLong(null, curPageBaseOffset + curOffsetInPage); + curOffsetInPage += Long.BYTES; + baseOffset = curPageBaseOffset + curOffsetInPage; + curOffsetInPage += recordLength; + curNumOfRecInPage ++; + curNumOfRec ++; + } + + private void loadNextOnDisk() throws IOException { + if (diskSpillReader != null && diskSpillReader.hasNext()) { + diskSpillReader.loadNext(); + baseObject = diskSpillReader.getBaseObject(); + baseOffset = diskSpillReader.getBaseOffset(); + recordLength = diskSpillReader.getRecordLength(); + keyPrefix = diskSpillReader.getKeyPrefix(); + curNumOfRec ++; + } + } + + private void moveToNextPMemPage() { + curPageIdx++; + curPage = allocatedPMemPages.get(curPageIdx); + curOffsetInPage = 0; + curNumOfRecInPage = 0; + } + + @Override + public Object getBaseObject() { + return baseObject; + } + + @Override + public long getBaseOffset() { + return baseOffset; + } + + @Override + public int getRecordLength() { + return recordLength; + } + + @Override + public long getKeyPrefix() { + return keyPrefix; + } + + @Override + public int getNumRecords() { + return numRecordsOnPMem + numRecordsOnDisk; + } + } +} diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/SpillWriterForUnsafeSorter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/SpillWriterForUnsafeSorter.java new file mode 100644 index 00000000..06a670e8 --- /dev/null +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/SpillWriterForUnsafeSorter.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.io.IOException; + +public interface SpillWriterForUnsafeSorter { + //write down all the spills. + public void write() throws IOException; + + //get reader for the spill maintained by this writer. + public UnsafeSorterIterator getSpillReader() throws IOException; + + //clear all acquired resource after read is done. + public void clearAll(); + + //get spilled record number. 
+ public int recordsSpilled(); +} diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 6b587ac0..9ecb6e4e 100644 --- a/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -25,20 +25,17 @@ import java.util.function.Supplier; import com.google.common.annotations.VisibleForTesting; - -import com.intel.oap.common.storage.stream.ChunkInputStream; -import com.intel.oap.common.storage.stream.DataStore; - +import org.apache.spark.SparkEnv; +import org.apache.spark.internal.config.package$; +import org.apache.spark.memory.SparkOutOfMemoryError; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.SparkEnv; +import org.apache.spark.internal.config.package$; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; -import org.apache.spark.internal.config.package$; import org.apache.spark.memory.MemoryConsumer; -import org.apache.spark.memory.PMemManagerInitializer; -import org.apache.spark.memory.SparkOutOfMemoryError; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.memory.TooLargePageException; import org.apache.spark.serializer.SerializerManager; @@ -80,7 +77,12 @@ public final class UnsafeExternalSorter extends MemoryConsumer { * Force this sorter to spill when there are this many elements in memory. */ private final int numElementsForSpillThreshold; - + private final boolean spillToPMemEnabled = SparkEnv.get() != null && (boolean) SparkEnv.get().conf().get( + package$.MODULE$.MEMORY_SPILL_PMEM_ENABLED()); + /** + * spillWriterType + */ + private String spillWriterType = null; /** * Memory pages that hold the records being sorted. The pages in this list are freed when * spilling, although in principle we could recycle these pages across spills (on the other hand, @@ -89,7 +91,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { */ private final LinkedList allocatedPages = new LinkedList<>(); - private final LinkedList spillWriters = new LinkedList<>(); + private final LinkedList spillWriters = new LinkedList<>(); // These variables are reset after spilling: @Nullable private volatile UnsafeInMemorySorter inMemSorter; @@ -159,7 +161,12 @@ private UnsafeExternalSorter( // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024 this.fileBufferSizeBytes = 32 * 1024; - + SparkEnv sparkEnv = SparkEnv.get(); + if (sparkEnv != null && sparkEnv.conf() != null){ + this.spillWriterType = sparkEnv.conf().get(package$.MODULE$.USAFE_EXTERNAL_SORTER_SPILL_WRITE_TYPE()); + } else { + this.spillWriterType = PMemSpillWriterType.WRITE_SORTED_RECORDS_TO_PMEM.toString(); + } if (existingInMemorySorter == null) { RecordComparator comparator = null; if (recordComparatorSupplier != null) { @@ -214,34 +221,52 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { } logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", - Thread.currentThread().getId(), - Utils.bytesToString(getMemoryUsage()), - spillWriters.size(), - spillWriters.size() > 1 ? " times" : " time"); - + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spillWriters.size() , + spillWriters.size() > 1 ? 
" times" : " time"); ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); - - final UnsafeSorterSpillWriter spillWriter = - new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, - inMemSorter.numRecords()); - spillWriters.add(spillWriter); - spillIterator(inMemSorter.getSortedIterator(), spillWriter); - + // Sorting records or not will be handled by different spill writer, here null is given instead. + spillWithWriter(null, inMemSorter.numRecords(), writeMetrics, false); final long spillSize = freeMemory(); + inMemSorter.reset(); // Note that this is more-or-less going to be a multiple of the page size, so wasted space in // pages will currently be counted as memory spilled even though that space isn't actually // written to disk. This also counts the space needed to store the sorter's pointer array. - inMemSorter.reset(); // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the // records. Otherwise, if the task is over allocated memory, then without freeing the memory // pages, we might not be able to get memory for the pointer array. - taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten()); totalSpillBytes += spillSize; return spillSize; } + //Todo: It's confusing to pass in ShuffleWriteMetrics here. Will reconsider and fix it later + public SpillWriterForUnsafeSorter spillWithWriter( + UnsafeSorterIterator sortedIterator, + int numberOfRecordsToWritten, + ShuffleWriteMetrics writeMetrics, + boolean isSorted) throws IOException { + PMemSpillWriterType writerType = PMemSpillWriterType.valueOf(spillWriterType); + logger.info("PMemSpillWriterType:{}",writerType.toString()); + final SpillWriterForUnsafeSorter spillWriter = PMemSpillWriterFactory.getSpillWriter( + writerType, + this, + sortedIterator, + numberOfRecordsToWritten, + serializerManager, + blockManager, + fileBufferSizeBytes, + writeMetrics, + taskContext.taskMetrics(), + spillToPMemEnabled, + isSorted); + spillWriter.write(); + spillWriters.add(spillWriter); + return spillWriter; + } + /** * Return the total memory usage of this sorter, including the data pages and the sorter's pointer * array. @@ -292,6 +317,9 @@ public int getNumberOfAllocatedPages() { return allocatedPages.size(); } + public LinkedList getAllocatedPages() { + return this.allocatedPages; + } /** * Free this sorter's data pages. * @@ -310,41 +338,12 @@ private long freeMemory() { return memoryFreed; } - /** - * Deletes any spill files created by this sorter. 
- */ - private void deleteSpillFiles() { - boolean pMemSpillEnabled = SparkEnv.get() != null && (boolean) SparkEnv.get().conf().get( - package$.MODULE$.MEMORY_SPILL_PMEM_ENABLED()); - if (pMemSpillEnabled == true) - deletePMemSpillFiles(); - else - deleteDiskSpillFiles(); - } - private void deleteDiskSpillFiles() { - while (!spillWriters.isEmpty()) { - File file = spillWriters.removeFirst().getFile(); - if (file != null && file.exists()) { - if (!file.delete()) { - logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); - } - } - } - } - - private void deletePMemSpillFiles() { - while (!spillWriters.isEmpty()) { - File file = spillWriters.removeFirst().getFile(); - try { - ChunkInputStream cis = ChunkInputStream.getChunkInputStreamInstance(file.toString(), - new DataStore(PMemManagerInitializer.getPMemManager(), - PMemManagerInitializer.getProperties())); - cis.free(); - } catch (IOException e) { - logger.debug(e.toString()); - } + private void freeSpills() { + for (SpillWriterForUnsafeSorter spillWriter : spillWriters) { + spillWriter.clearAll(); } + spillWriters.clear(); } /** @@ -352,7 +351,7 @@ private void deletePMemSpillFiles() { */ public void cleanupResources() { synchronized (this) { - deleteSpillFiles(); + freeSpills(); freeMemory(); if (inMemSorter != null) { inMemSorter.free(); @@ -470,7 +469,6 @@ public void insertKVRecord(Object keyBase, long keyOffset, int keyLen, pageCursor += keyLen; Platform.copyMemory(valueBase, valueOffset, base, pageCursor, valueLen); pageCursor += valueLen; - assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull); } @@ -499,8 +497,8 @@ public UnsafeSorterIterator getSortedIterator() throws IOException { } else { final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger( recordComparatorSupplier.get(), prefixComparator, spillWriters.size()); - for (UnsafeSorterSpillWriter spillWriter : spillWriters) { - spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager)); + for (SpillWriterForUnsafeSorter spillWriter: spillWriters) { + spillMerger.addSpillIfNotEmpty(spillWriter.getSpillReader()); } if (inMemSorter != null) { readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); @@ -509,23 +507,12 @@ public UnsafeSorterIterator getSortedIterator() throws IOException { return spillMerger.getSortedIterator(); } } + public UnsafeInMemorySorter getInMemSorter() { return inMemSorter; } @VisibleForTesting boolean hasSpaceForAnotherRecord() { return inMemSorter.hasSpaceForAnotherRecord(); } - private static void spillIterator(UnsafeSorterIterator inMemIterator, - UnsafeSorterSpillWriter spillWriter) throws IOException { - while (inMemIterator.hasNext()) { - inMemIterator.loadNext(); - final Object baseObject = inMemIterator.getBaseObject(); - final long baseOffset = inMemIterator.getBaseOffset(); - final int recordLength = inMemIterator.getRecordLength(); - spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix()); - } - spillWriter.close(); - } - /** * An UnsafeSorterIterator that support spilling. */ @@ -552,19 +539,14 @@ public long spill() throws IOException { && numRecords > 0)) { return 0L; } - UnsafeInMemorySorter.SortedIterator inMemIterator = - ((UnsafeInMemorySorter.SortedIterator) upstream).clone(); - - ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); - // Iterate over the records that have not been returned and spill them. 
- final UnsafeSorterSpillWriter spillWriter = - new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords); - spillIterator(inMemIterator, spillWriter); - spillWriters.add(spillWriter); - nextUpstream = spillWriter.getReader(serializerManager); - + ((UnsafeInMemorySorter.SortedIterator) upstream).clone(); + + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); long released = 0L; + SpillWriterForUnsafeSorter spillWriter = spillWithWriter(inMemIterator, numRecords, writeMetrics, true); + nextUpstream = spillWriter.getSpillReader(); + synchronized (UnsafeExternalSorter.this) { // release the pages except the one that is used. There can still be a caller that // is accessing the current record. We free this page in that caller's next loadNext() @@ -668,9 +650,9 @@ public UnsafeSorterIterator getIterator(int startIndex) throws IOException { } else { LinkedList queue = new LinkedList<>(); int i = 0; - for (UnsafeSorterSpillWriter spillWriter : spillWriters) { + for (SpillWriterForUnsafeSorter spillWriter : spillWriters) { if (i + spillWriter.recordsSpilled() > startIndex) { - UnsafeSorterIterator iter = spillWriter.getReader(serializerManager); + UnsafeSorterIterator iter = spillWriter.getSpillReader(); moveOver(iter, startIndex - i); queue.add(iter); } diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java new file mode 100644 index 00000000..b1c1d393 --- /dev/null +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import java.util.Comparator; +import java.util.LinkedList; + +import org.apache.avro.reflect.Nullable; + +import org.apache.spark.TaskContext; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.SparkOutOfMemoryError; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.util.collection.Sorter; + +/** + * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records + * alongside a user-defined prefix of the record's sorting key. When the underlying sort algorithm + * compares records, it will first compare the stored key prefixes; if the prefixes are not equal, + * then we do not need to traverse the record pointers to compare the actual records. Avoiding these + * random memory accesses improves cache hit rates. 
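+ *
+ * (Illustrative example, not from the original javadoc: when sorting string keys,
+ * an 8-byte prefix built from the leading bytes of each key usually decides the
+ * comparison on its own; only records whose prefixes compare equal fall back to the
+ * full record comparator and its pointer dereferences.)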
+ */ +public final class UnsafeInMemorySorter { + + private static final class SortComparator implements Comparator { + + private final RecordComparator recordComparator; + private final PrefixComparator prefixComparator; + private final TaskMemoryManager memoryManager; + + SortComparator( + RecordComparator recordComparator, + PrefixComparator prefixComparator, + TaskMemoryManager memoryManager) { + this.recordComparator = recordComparator; + this.prefixComparator = prefixComparator; + this.memoryManager = memoryManager; + } + + @Override + public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { + final int prefixComparisonResult = prefixComparator.compare(r1.keyPrefix, r2.keyPrefix); + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + if (prefixComparisonResult == 0) { + final Object baseObject1 = memoryManager.getPage(r1.recordPointer); + final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + uaoSize; + final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize); + final Object baseObject2 = memoryManager.getPage(r2.recordPointer); + final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + uaoSize; + final int baseLength2 = UnsafeAlignedOffset.getSize(baseObject2, baseOffset2 - uaoSize); + return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2, + baseOffset2, baseLength2); + } else { + return prefixComparisonResult; + } + } + } + + private final MemoryConsumer consumer; + private final TaskMemoryManager memoryManager; + @Nullable + private final Comparator sortComparator; + + /** + * If non-null, specifies the radix sort parameters and that radix sort will be used. + */ + @Nullable + private final PrefixComparators.RadixSortSupport radixSortSupport; + + /** + * Within this buffer, position {@code 2 * i} holds a pointer to the record at + * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + * + * Only part of the array will be used to store the pointers, the rest part is preserved as + * temporary buffer for sorting. + */ + private LongArray array; + + /** + * The position in the sort buffer where new records can be inserted. + */ + private int pos = 0; + + /** + * If sorting with radix sort, specifies the starting position in the sort buffer where records + * with non-null prefixes are kept. Positions [0..nullBoundaryPos) will contain null-prefixed + * records, and positions [nullBoundaryPos..pos) non-null prefixed records. This lets us avoid + * radix sorting over null values. + */ + private int nullBoundaryPos = 0; + + /* + * How many records could be inserted, because part of the array should be left for sorting. 
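+ * (Radix sort needs a scratch region as large as the data itself, so roughly half
+ * of the array is usable; Tim sort needs half as much scratch, so about two thirds
+ * is usable. See getUsableCapacity().)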
+ */ + private int usableCapacity = 0; + + private long initialSize; + + private long totalSortTimeNanos = 0L; + + public UnsafeInMemorySorter( + final MemoryConsumer consumer, + final TaskMemoryManager memoryManager, + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + int initialSize, + boolean canUseRadixSort) { + this(consumer, memoryManager, recordComparator, prefixComparator, + consumer.allocateArray(initialSize * 2L), canUseRadixSort); + } + + public UnsafeInMemorySorter( + final MemoryConsumer consumer, + final TaskMemoryManager memoryManager, + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + LongArray array, + boolean canUseRadixSort) { + this.consumer = consumer; + this.memoryManager = memoryManager; + this.initialSize = array.size(); + if (recordComparator != null) { + this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + if (canUseRadixSort && prefixComparator instanceof PrefixComparators.RadixSortSupport) { + this.radixSortSupport = (PrefixComparators.RadixSortSupport)prefixComparator; + } else { + this.radixSortSupport = null; + } + } else { + this.sortComparator = null; + this.radixSortSupport = null; + } + this.array = array; + this.usableCapacity = getUsableCapacity(); + } + + private int getUsableCapacity() { + // Radix sort requires same amount of used memory as buffer, Tim sort requires + // half of the used memory as buffer. + return (int) (array.size() / (radixSortSupport != null ? 2 : 1.5)); + } + + /** + * Free the memory used by pointer array. + */ + public void freeWithoutLongArray() { + if (consumer != null) { + array = null; + } + } + + public void free() { + if (consumer != null) { + if (array != null) { + consumer.freeArray(array); + } + array = null; + } + } + + public void resetWithoutLongArrray() { + if (consumer != null) { + // the call to consumer.allocateArray may trigger a spill which in turn access this instance + // and eventually re-enter this method and try to free the array again. by setting the array + // to null and its length to 0 we effectively make the spill code-path a no-op. setting the + // array to null also indicates that it has already been de-allocated which prevents a double + // de-allocation in free(). + array = null; + usableCapacity = 0; + pos = 0; + nullBoundaryPos = 0; + array = consumer.allocateArray(initialSize); + usableCapacity = getUsableCapacity(); + } + pos = 0; + nullBoundaryPos = 0; + } + + public void reset() { + if (consumer != null) { + consumer.freeArray(array); + // the call to consumer.allocateArray may trigger a spill which in turn access this instance + // and eventually re-enter this method and try to free the array again. by setting the array + // to null and its length to 0 we effectively make the spill code-path a no-op. setting the + // array to null also indicates that it has already been de-allocated which prevents a double + // de-allocation in free(). + array = null; + usableCapacity = 0; + pos = 0; + nullBoundaryPos = 0; + array = consumer.allocateArray(initialSize); + usableCapacity = getUsableCapacity(); + } + pos = 0; + nullBoundaryPos = 0; + } + + /** + * @return the number of records that have been inserted into this sorter. + */ + public int numRecords() { + return pos / 2; + } + + /** + * @return the total amount of time spent sorting data (in-memory only). 
+ */ + public long getSortTimeNanos() { + return totalSortTimeNanos; + } + + public long getMemoryUsage() { + if (array == null) { + return 0L; + } + + return array.size() * 8; + } + + public LongArray getSortedArray() { + getSortedIterator(); + return array; + } + + public LongArray getArray() { + return array; + } + + public boolean hasSpaceForAnotherRecord() { + return pos + 1 < usableCapacity; + } + + public void expandPointerArray(LongArray newArray) { + if (newArray.size() < array.size()) { + // checkstyle.off: RegexpSinglelineJava + throw new SparkOutOfMemoryError("Not enough memory to grow pointer array"); + // checkstyle.on: RegexpSinglelineJava + } + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + newArray.getBaseObject(), + newArray.getBaseOffset(), + pos * 8L); + consumer.freeArray(array); + array = newArray; + usableCapacity = getUsableCapacity(); + } + + /** + * Inserts a record to be sorted. Assumes that the record pointer points to a record length + * stored as a uaoSize(4 or 8) bytes integer, followed by the record's bytes. + * + * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}. + * @param keyPrefix a user-defined key prefix + */ + public void insertRecord(long recordPointer, long keyPrefix, boolean prefixIsNull) { + if (!hasSpaceForAnotherRecord()) { + throw new IllegalStateException("There is no space for new record"); + } + if (prefixIsNull && radixSortSupport != null) { + // Swap forward a non-null record to make room for this one at the beginning of the array. + array.set(pos, array.get(nullBoundaryPos)); + pos++; + array.set(pos, array.get(nullBoundaryPos + 1)); + pos++; + // Place this record in the vacated position. + array.set(nullBoundaryPos, recordPointer); + nullBoundaryPos++; + array.set(nullBoundaryPos, keyPrefix); + nullBoundaryPos++; + } else { + array.set(pos, recordPointer); + pos++; + array.set(pos, keyPrefix); + pos++; + } + } + + public final class SortedIterator extends UnsafeSorterIterator implements Cloneable { + + private final int numRecords; + private int position; + private int offset; + private Object baseObject; + private long baseOffset; + private long keyPrefix; + private long currentRecordPointer; + private int recordLength; + private long currentPageNumber; + private final TaskContext taskContext = TaskContext.get(); + + private SortedIterator(int numRecords, int offset) { + this.numRecords = numRecords; + this.position = 0; + this.offset = offset; + } + + public SortedIterator clone() { + SortedIterator iter = new SortedIterator(numRecords, offset); + iter.position = position; + iter.baseObject = baseObject; + iter.baseOffset = baseOffset; + iter.keyPrefix = keyPrefix; + iter.recordLength = recordLength; + iter.currentPageNumber = currentPageNumber; + return iter; + } + + @Override + public int getNumRecords() { + return numRecords; + } + + @Override + public boolean hasNext() { + return position / 2 < numRecords; + } + + @Override + public void loadNext() { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. This check is added here in `loadNext()` instead of in + // `hasNext()` because it's technically possible for the caller to be relying on + // `getNumRecords()` instead of `hasNext()` to know when to stop. 
+ if (taskContext != null) { + taskContext.killTaskIfInterrupted(); + } + // This pointer points to a 4-byte record length, followed by the record's bytes + final long recordPointer = array.get(offset + position); + currentRecordPointer = recordPointer; + currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer); + int uaoSize = UnsafeAlignedOffset.getUaoSize(); + baseObject = memoryManager.getPage(recordPointer); + // Skip over record length + baseOffset = memoryManager.getOffsetInPage(recordPointer) + uaoSize; + recordLength = UnsafeAlignedOffset.getSize(baseObject, baseOffset - uaoSize); + keyPrefix = array.get(offset + position + 1); + position += 2; + } + + @Override + public Object getBaseObject() { return baseObject; } + + @Override + public long getBaseOffset() { return baseOffset; } + + public long getCurrentPageNumber() { + return currentPageNumber; + } + + @Override + public int getRecordLength() { return recordLength; } + + @Override + public long getKeyPrefix() { return keyPrefix; } + + public int getPosition() { return position; } + + public int getOffset() { return offset; } + + public long getCurrentRecordPointer() { return currentRecordPointer; } + } + + /** + * Return an iterator over record pointers in sorted order. For efficiency, all calls to + * {@code next()} will return the same mutable object. + */ + public UnsafeSorterIterator getSortedIterator() { + int offset = 0; + long start = System.nanoTime(); + if (sortComparator != null) { + if (this.radixSortSupport != null) { + offset = RadixSort.sortKeyPrefixArray( + array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7, + radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); + } else { + MemoryBlock unused = new MemoryBlock( + array.getBaseObject(), + array.getBaseOffset() + pos * 8L, + (array.size() - pos) * 8L); + LongArray buffer = new LongArray(unused); + Sorter sorter = + new Sorter<>(new UnsafeSortDataFormat(buffer)); + sorter.sort(array, 0, pos / 2, sortComparator); + } + } + totalSortTimeNanos += System.nanoTime() - start; + if (nullBoundaryPos > 0) { + assert radixSortSupport != null : "Nulls are only stored separately with radix sort"; + LinkedList queue = new LinkedList<>(); + + // The null order is either LAST or FIRST, regardless of sorting direction (ASC|DESC) + if (radixSortSupport.nullsFirst()) { + queue.add(new SortedIterator(nullBoundaryPos / 2, 0)); + queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset)); + } else { + queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset)); + queue.add(new SortedIterator(nullBoundaryPos / 2, 0)); + } + return new UnsafeExternalSorter.ChainedIterator(queue); + } else { + return new SortedIterator(pos / 2, offset); + } + } + + public LongArray getLongArray() { + return array; + } + + public TaskMemoryManager getTaskMemoryManager() { + return memoryManager; + } + +} diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterPMemSpillWriter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterPMemSpillWriter.java new file mode 100644 index 00000000..becab97d --- /dev/null +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterPMemSpillWriter.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.unsafe.memory.MemoryBlock; + +import java.io.IOException; +import java.util.LinkedList; + +public abstract class UnsafeSorterPMemSpillWriter implements SpillWriterForUnsafeSorter{ + /** + * the memConsumer used to allocate pmem pages + */ + protected UnsafeExternalSorter externalSorter; + + protected SortedIteratorForSpills sortedIterator; + + protected int numberOfRecordsToWritten = 0; + + protected TaskMemoryManager taskMemoryManager; + + //Todo: It's confusing to have ShuffleWriteMetrics here. will reconsider and fix it later. + protected ShuffleWriteMetrics writeMetrics; + + protected TaskMetrics taskMetrics; + + protected LinkedList allocatedPMemPages = new LinkedList(); + + //Page size in bytes. + private static long DEFAULT_PAGE_SIZE = 64*1024*1024; + + public UnsafeSorterPMemSpillWriter( + UnsafeExternalSorter externalSorter, + SortedIteratorForSpills sortedIterator, + int numberOfRecordsToWritten, + ShuffleWriteMetrics writeMetrics, + TaskMetrics taskMetrics) { + this.externalSorter = externalSorter; + this.taskMemoryManager = externalSorter.getTaskMemoryManager(); + this.sortedIterator = sortedIterator; + this.numberOfRecordsToWritten = numberOfRecordsToWritten; + this.writeMetrics = writeMetrics; + this.taskMetrics = taskMetrics; + } + + protected MemoryBlock allocatePMemPage() throws IOException { ; + return allocatePMemPage(DEFAULT_PAGE_SIZE); + } + + protected MemoryBlock allocatePMemPage(long size) { + MemoryBlock page = taskMemoryManager.allocatePage(size, externalSorter, true); + if (page != null) { + allocatedPMemPages.add(page); + } + return page; + } + + protected void freeAllPMemPages() { + for (MemoryBlock page : allocatedPMemPages) { + taskMemoryManager.freePage(page, externalSorter); + } + allocatedPMemPages.clear(); + } + public abstract UnsafeSorterIterator getSpillReader() throws IOException; +} diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 32c2d79a..25ca11fb 100644 --- a/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -19,15 +19,13 @@ import com.google.common.io.ByteStreams; import com.google.common.io.Closeables; -import com.intel.oap.common.storage.stream.ChunkInputStream; -import com.intel.oap.common.storage.stream.DataStore; import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; +import org.apache.spark.executor.TaskMetrics; import 
org.apache.spark.internal.config.package$; import org.apache.spark.internal.config.ConfigEntry; import org.apache.spark.io.NioBufferedFileInputStream; import org.apache.spark.io.ReadAheadInputStream; -import org.apache.spark.memory.PMemManagerInitializer; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; import org.apache.spark.unsafe.Platform; @@ -38,26 +36,34 @@ * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description * of the file format). */ -public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable { +public class UnsafeSorterSpillReader extends UnsafeSorterIterator implements Closeable { public static final int MAX_BUFFER_SIZE_BYTES = 16777216; // 16 mb - private InputStream in; - private DataInputStream din; + protected InputStream in; + protected DataInputStream din; // Variables that change with every record read: - private int recordLength; - private long keyPrefix; - private int numRecords; - private int numRecordsRemaining; - - private byte[] arr = new byte[1024 * 1024]; - private Object baseObject = arr; - private final TaskContext taskContext = TaskContext.get(); + protected int recordLength; + protected long keyPrefix; + protected int numRecords; + protected int numRecordsRemaining; + + protected byte[] arr = new byte[1024 * 1024]; + protected Object baseObject = arr; + protected final TaskContext taskContext = TaskContext.get(); + protected final TaskMetrics taskMetrics; + + public UnsafeSorterSpillReader(TaskMetrics taskMetrics) { + this.taskMetrics = taskMetrics; + } public UnsafeSorterSpillReader( SerializerManager serializerManager, + TaskMetrics taskMetrics, File file, BlockId blockId) throws IOException { + assert (file.length() > 0); + this.taskMetrics = taskMetrics; final ConfigEntry bufferSizeConfigEntry = package$.MODULE$.UNSAFE_SORTER_SPILL_READER_BUFFER_SIZE(); // This value must be less than or equal to MAX_BUFFER_SIZE_BYTES. Cast to int is always safe. @@ -69,12 +75,8 @@ public UnsafeSorterSpillReader( final boolean readAheadEnabled = SparkEnv.get() != null && (boolean)SparkEnv.get().conf().get( package$.MODULE$.UNSAFE_SORTER_SPILL_READ_AHEAD_ENABLED()); - boolean pMemSpillEnabled = SparkEnv.get() != null && (boolean) SparkEnv.get().conf().get( - package$.MODULE$.MEMORY_SPILL_PMEM_ENABLED()); - final InputStream bs = pMemSpillEnabled ? 
ChunkInputStream.getChunkInputStreamInstance(file.toString(), - new DataStore(PMemManagerInitializer.getPMemManager(), - PMemManagerInitializer.getProperties())) : - new NioBufferedFileInputStream(file, bufferSizeBytes); + final InputStream bs = + new NioBufferedFileInputStream(file, bufferSizeBytes); try { if (readAheadEnabled) { this.in = new ReadAheadInputStream(serializerManager.wrapStream(blockId, bs), diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index cc9ee498..7ee391d7 100644 --- a/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -20,10 +20,10 @@ import java.io.File; import java.io.IOException; +import org.apache.spark.executor.TaskMetrics; import scala.Tuple2; import org.apache.spark.SparkConf; -import org.apache.spark.SparkEnv; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.serializer.DummySerializerInstance; @@ -33,60 +33,99 @@ import org.apache.spark.storage.TempLocalBlockId; import org.apache.spark.unsafe.Platform; import org.apache.spark.internal.config.package$; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Spills a list of sorted records to disk. Spill files have the following format: * * [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...] */ -public final class UnsafeSorterSpillWriter { +public class UnsafeSorterSpillWriter implements SpillWriterForUnsafeSorter{ + private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillWriter.class); - private final SparkConf conf = new SparkConf(); + protected final SparkConf conf = new SparkConf(); /** * The buffer size to use when writing the sorted records to an on-disk file, and * this space used by prefix + len + recordLength must be greater than 4 + 8 bytes. */ - private final int diskWriteBufferSize = + protected final int diskWriteBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE()); // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to // be an API to directly transfer bytes from managed memory to the disk writer, we buffer // data through a byte array. 
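(Editorial note: in this hunk the class drops its final modifier and its buffer fields
become protected, presumably so that the PMem stream spill writer,
UnsafeSorterStreamSpillWriter, can subclass it and reuse the buffering logic.)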
- private byte[] writeBuffer = new byte[diskWriteBufferSize]; + protected byte[] writeBuffer = new byte[diskWriteBufferSize]; - private final File file; - private final BlockId blockId; - private final int numRecordsToWrite; - private DiskBlockObjectWriter writer; - private int numRecordsSpilled = 0; + protected File file = null; + protected BlockId blockId = null; + protected int numRecordsToWrite = 0; + protected DiskBlockObjectWriter writer; + protected int numRecordsSpilled = 0; + protected UnsafeSorterIterator inMemIterator; + protected SerializerManager serializerManager; + protected TaskMetrics taskMetrics; + + public UnsafeSorterSpillWriter() {} public UnsafeSorterSpillWriter( BlockManager blockManager, int fileBufferSize, + UnsafeSorterIterator inMemIterator, + int numRecordsToWrite, + SerializerManager serializerManager, ShuffleWriteMetrics writeMetrics, - int numRecordsToWrite) throws IOException { + TaskMetrics taskMetrics) { final Tuple2 spilledFileInfo = blockManager.diskBlockManager().createTempLocalBlock(); this.file = spilledFileInfo._2(); this.blockId = spilledFileInfo._1(); this.numRecordsToWrite = numRecordsToWrite; + this.serializerManager = serializerManager; + this.taskMetrics = taskMetrics; + this.inMemIterator = inMemIterator; // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. // Our write path doesn't actually use this serializer (since we end up calling the `write()` // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work // around this, we pass a dummy no-op serializer. - boolean pMemSpillEnabled = SparkEnv.get() != null && (boolean) SparkEnv.get().conf().get( - package$.MODULE$.MEMORY_SPILL_PMEM_ENABLED()); - - writer = pMemSpillEnabled == true ? blockManager.getPMemWriter( - blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics) - : blockManager.getDiskWriter( - blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics); + writer = blockManager.getDiskWriter( + blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics); // Write the number of records writeIntToBuffer(numRecordsToWrite, 0); writer.write(writeBuffer, 0, 4); } + public UnsafeSorterSpillWriter( + BlockManager blockManager, + int fileBufferSize, + ShuffleWriteMetrics writeMetrics, + int numRecordsToWrite) throws IOException { + this(blockManager, + fileBufferSize, + null, + numRecordsToWrite, + null, + writeMetrics, + null); + } + + public UnsafeSorterSpillWriter( + BlockManager blockManager, + int fileBufferSize, + UnsafeSorterIterator inMemIterator, + SerializerManager serializerManager, + ShuffleWriteMetrics writeMetrics, + TaskMetrics taskMetrics) throws IOException { + this(blockManager, + fileBufferSize, + inMemIterator, + inMemIterator.getNumRecords(), + serializerManager, + writeMetrics, + taskMetrics); + } + // Based on DataOutputStream.writeLong. private void writeLongToBuffer(long v, int offset) { writeBuffer[offset + 0] = (byte)(v >>> 56); @@ -100,7 +139,7 @@ private void writeLongToBuffer(long v, int offset) { } // Based on DataOutputStream.writeInt. 
- private void writeIntToBuffer(int v, int offset) { + protected void writeIntToBuffer(int v, int offset) { writeBuffer[offset + 0] = (byte)(v >>> 24); writeBuffer[offset + 1] = (byte)(v >>> 16); writeBuffer[offset + 2] = (byte)(v >>> 8); @@ -161,11 +200,50 @@ public File getFile() { return file; } - public UnsafeSorterSpillReader getReader(SerializerManager serializerManager) throws IOException { - return new UnsafeSorterSpillReader(serializerManager, file, blockId); - } - public int recordsSpilled() { return numRecordsSpilled; } + + @Override + public void write() throws IOException { + write(false); + } + + public void write(boolean alreadyLoad) throws IOException { + if (inMemIterator != null) { + if (alreadyLoad) { + final Object baseObject = inMemIterator.getBaseObject(); + final long baseOffset = inMemIterator.getBaseOffset(); + final int recordLength = inMemIterator.getRecordLength(); + write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix()); + } + while (inMemIterator.hasNext()) { + inMemIterator.loadNext(); + final Object baseObject = inMemIterator.getBaseObject(); + final long baseOffset = inMemIterator.getBaseOffset(); + final int recordLength = inMemIterator.getRecordLength(); + write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix()); + } + close(); + } + } + + public UnsafeSorterSpillReader getReader(SerializerManager serializerManager, + TaskMetrics taskMetrics) throws IOException { + return new UnsafeSorterSpillReader(serializerManager, taskMetrics, file, blockId); + } + + @Override + public UnsafeSorterIterator getSpillReader() throws IOException{ + return new UnsafeSorterSpillReader(serializerManager, taskMetrics, file, blockId); + } + + @Override + public void clearAll() { + if (file != null && file.exists()) { + if (!file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); + } + } + } } diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterStreamSpillReader.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterStreamSpillReader.java new file mode 100644 index 00000000..96438686 --- /dev/null +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterStreamSpillReader.java @@ -0,0 +1,57 @@ +package org.apache.spark.util.collection.unsafe.sort; + +import com.google.common.io.Closeables; +import com.intel.oap.common.storage.stream.ChunkInputStream; +import com.intel.oap.common.storage.stream.DataStore; +import org.apache.spark.SparkEnv; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.internal.config.ConfigEntry; +import org.apache.spark.internal.config.package$; +import org.apache.spark.io.ReadAheadInputStream; +import org.apache.spark.memory.PMemManagerInitializer; +import org.apache.spark.serializer.SerializerManager; +import org.apache.spark.storage.BlockId; + +import java.io.DataInputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; + +public class UnsafeSorterStreamSpillReader extends UnsafeSorterSpillReader { + + protected final ChunkInputStream chunkInputStream; + public UnsafeSorterStreamSpillReader( + SerializerManager serializerManager, + TaskMetrics taskMetrics, + File file, + BlockId blockId) throws IOException { + super(taskMetrics); + final ConfigEntry bufferSizeConfigEntry = + package$.MODULE$.UNSAFE_SORTER_SPILL_READER_BUFFER_SIZE(); + // This value must be less than or equal to MAX_BUFFER_SIZE_BYTES. Cast to int is always safe. 
+ final int DEFAULT_BUFFER_SIZE_BYTES = + ((Long) bufferSizeConfigEntry.defaultValue().get()).intValue(); + int bufferSizeBytes = SparkEnv.get() == null ? DEFAULT_BUFFER_SIZE_BYTES : + ((Long) SparkEnv.get().conf().get(bufferSizeConfigEntry)).intValue(); + final boolean readAheadEnabled = SparkEnv.get() != null && (boolean)SparkEnv.get().conf().get( + package$.MODULE$.UNSAFE_SORTER_SPILL_READ_AHEAD_ENABLED()); + chunkInputStream = + ChunkInputStream.getChunkInputStreamInstance(file.toString(), + new DataStore(PMemManagerInitializer.getPMemManager(), + PMemManagerInitializer.getProperties())); + final InputStream bs = chunkInputStream; + try { + if (readAheadEnabled) { + this.in = new ReadAheadInputStream(serializerManager.wrapStream(blockId, bs), + bufferSizeBytes); + } else { + this.in = serializerManager.wrapStream(blockId, bs); + } + this.din = new DataInputStream(this.in); + numRecords = numRecordsRemaining = din.readInt(); + } catch (IOException e) { + Closeables.close(bs, /* swallowIOException = */ true); + throw e; + } + } +} diff --git a/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterStreamSpillWriter.java b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterStreamSpillWriter.java new file mode 100644 index 00000000..3676800c --- /dev/null +++ b/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterStreamSpillWriter.java @@ -0,0 +1,62 @@ +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.serializer.DummySerializerInstance; +import org.apache.spark.serializer.SerializerManager; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.TempLocalBlockId; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Tuple2; + +import java.io.File; +import java.io.IOException; + +public class UnsafeSorterStreamSpillWriter extends UnsafeSorterSpillWriter { + private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterStreamSpillWriter.class); + private UnsafeSorterStreamSpillReader reader; + public UnsafeSorterStreamSpillWriter( + BlockManager blockManager, + int fileBufferSize, + UnsafeSorterIterator inMemIterator, + int numRecordsToWrite, + SerializerManager serializerManager, + ShuffleWriteMetrics writeMetrics, + TaskMetrics taskMetrics) { + super(); + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempLocalBlock(); + this.file = spilledFileInfo._2(); + this.blockId = spilledFileInfo._1(); + this.numRecordsToWrite = numRecordsToWrite; + this.serializerManager = serializerManager; + this.taskMetrics = taskMetrics; + this.inMemIterator = inMemIterator; + // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. 
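Note (editorial sketch, not part of this patch): whether records come back from a plain disk spill (UnsafeSorterSpillReader) or from the PMem-backed stream reader above, callers consume them through the same UnsafeSorterIterator contract. A sketch of that loop, with the per-record handling left abstract; the helper name is hypothetical.

```java
package org.apache.spark.util.collection.unsafe.sort;

import java.io.IOException;

// Illustration only.
final class SpillReaderDrainSketch {
  static long drain(UnsafeSorterIterator spilled) throws IOException {
    long consumed = 0;
    while (spilled.hasNext()) {
      spilled.loadNext();                     // position the iterator on the next record
      Object base = spilled.getBaseObject();  // backing object, or null for off-heap data
      long offset = spilled.getBaseOffset();
      int length = spilled.getRecordLength();
      long prefix = spilled.getKeyPrefix();
      // ... hand (base, offset, length, prefix) to the merge path ...
      consumed++;
    }
    return consumed;
  }
}
```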
+ writer = blockManager.getPMemWriter( + blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics); + // Write the number of records + writeIntToBuffer(numRecordsToWrite, 0); + writer.write(writeBuffer, 0, 4); + } + + @Override + public UnsafeSorterIterator getSpillReader() throws IOException { + reader = new UnsafeSorterStreamSpillReader(serializerManager, taskMetrics, file, blockId); + return reader; + } + + @Override + public void clearAll() { + assert(reader != null); + try { + reader.chunkInputStream.free(); + } catch (IOException e) { + logger.debug(e.toString()); + } + } +} diff --git a/src/main/scala/org/apache/spark/internal/config/package.scala b/src/main/scala/org/apache/spark/internal/config/package.scala index 0efb323a..07661590 100644 --- a/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/src/main/scala/org/apache/spark/internal/config/package.scala @@ -28,6 +28,7 @@ import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO import org.apache.spark.storage.{DefaultTopologyMapper, RandomBlockReplicationPolicy} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils +import org.apache.spark.util.collection.unsafe.sort.PMemSpillWriterType import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader.MAX_BUFFER_SIZE_BYTES package object config { @@ -336,6 +337,57 @@ package object config { .checkValue(_ >= 0, "The off-heap memory size must not be negative") .createWithDefault(0) + val MEMORY_SPILL_PMEM_ENABLED = + ConfigBuilder("spark.memory.spill.pmem.enabled") + .doc("Whether to spill memory to PMem instead of disk.") + .booleanConf + .createWithDefault(true) + + val MEMORY_EXTENDED_PATH = + ConfigBuilder("spark.memory.extended.path") + .doc("Initialization path for extended memory.") + .stringConf + .createWithDefault("/mnt/pmem") + + private[spark] val MEMORY_EXTENDED_SIZE = ConfigBuilder("spark.memory.extended.size") + .doc("The absolute amount of memory which can be used for extended memory allocation.") + .version("1.6.0") + .bytesConf(ByteUnit.BYTE) + .checkValue(_ >= 0, "The extended memory size must not be negative") + .createWithDefault(64 * 1024 * 1024) + + val MEMORY_SPILL_PMEM_SORT_BACKGROUND = + ConfigBuilder("spark.memory.spill.pmem.sort.background") + .doc("Whether to sort and dump pages to PMem concurrently.") + .booleanConf + .createWithDefault(false) + + val MEMORY_SPILL_PMEM_CLFLUSH_ENABLED = + ConfigBuilder("spark.memory.spill.pmem.clflush.enabled") + .doc("Enable clflush when copying to PMem.") + .booleanConf + .createWithDefault(false) + + val PMEM_PROPERTY_FILE = + ConfigBuilder("spark.memory.spill.pmem.config.file") + .doc("A config file used to configure Intel PMem settings for memory extension.") + .stringConf + .createWithDefault("pmem.properties") + + val USAFE_EXTERNAL_SORTER_SPILL_WRITE_TYPE = ConfigBuilder("spark.unsafe.sort.spillwriter.type") + .doc("The spill writer type for UnsafeExternalSorter.") + .stringConf + .createWithDefault(PMemSpillWriterType.WRITE_SORTED_RECORDS_TO_PMEM.toString()) + + private[spark] val MEMORY_SPILL_PMEM_READ_BUFFERSIZE = + ConfigBuilder("spark.memory.spill.pmem.readBufferSize") + .doc("The buffer size, in bytes, to use when reading records from PMem.") + .bytesConf(ByteUnit.BYTE) + .checkValue(v => v > 12 && v <= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH, + s"The buffer size must be greater than 12 and less than or equal to " + + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") + .createWithDefaultString("8m") + private[spark] val
MEMORY_STORAGE_FRACTION = ConfigBuilder("spark.memory.storageFraction") .doc("Amount of storage memory immune to eviction, expressed as a fraction of the " + "size of the region set aside by spark.memory.fraction. The higher this is, the " + @@ -356,18 +408,6 @@ package object config { .doubleConf .createWithDefault(0.6) - val MEMORY_SPILL_PMEM_ENABLED = - ConfigBuilder("spark.memory.spill.pmem.enabled") - .doc("Set memory spill to PMem instead of disk.") - .booleanConf - .createWithDefault(false) - - val PMEM_PROPERTY_FILE = - ConfigBuilder("spark.memory.spill.pmem.config.file") - .doc("A config file used to config Intel PMem settings for memory extension.") - .stringConf - .createWithDefault("pmem.properties") - private[spark] val STORAGE_SAFETY_FRACTION = ConfigBuilder("spark.storage.safetyFraction") .version("1.1.0") .doubleConf diff --git a/src/main/scala/org/apache/spark/memory/ExtendedMemoryPool.scala b/src/main/scala/org/apache/spark/memory/ExtendedMemoryPool.scala new file mode 100644 index 00000000..50dc05b8 --- /dev/null +++ b/src/main/scala/org/apache/spark/memory/ExtendedMemoryPool.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.memory + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable + +import org.apache.spark.internal.Logging + +private[memory] class ExtendedMemoryPool(lock: Object) extends MemoryPool(lock) with Logging { + + private[this] val poolName: String = "extended memory" + + /** + * Map from taskAttemptId -> memory consumption in bytes + */ + @GuardedBy("lock") + private val extendedMemoryForTask = new mutable.HashMap[Long, Long]() + + @GuardedBy("lock") + private[this] var _memoryUsed: Long = 0L + + override def memoryUsed: Long = lock.synchronized { + _memoryUsed + } + /** + * Returns the memory consumption, in bytes, for the given task. + */ + def getMemoryUsageForTask(taskAttemptId: Long): Long = lock.synchronized { + extendedMemoryForTask.getOrElse(taskAttemptId, 0L) + } + + /** + * Try to acquire up to `numBytes` of extended memory for the given task and return the number + * of bytes obtained, or 0 if none can be allocated. + * + * @param numBytes number of bytes to acquire + * @param taskAttemptId the task attempt acquiring memory + * @return the number of bytes granted to the task. 
+ */ + private[memory] def acquireMemory( + numBytes: Long, + taskAttemptId: Long): Long = lock.synchronized { + assert(numBytes > 0, s"invalid number of bytes requested: $numBytes") + + // Add this task to the extendedMemoryForTask map just so we can keep an accurate count + // of the number of active tasks + if (!extendedMemoryForTask.contains(taskAttemptId)) { + extendedMemoryForTask(taskAttemptId) = 0L + // This will later cause waiting tasks to wake up and check numTasks again + lock.notifyAll() + } + + if (memoryFree >= numBytes) { + _memoryUsed += numBytes + extendedMemoryForTask(taskAttemptId) += numBytes + return numBytes + } + 0L // Not enough free extended memory to grant this request + } + + /** + * Release `numBytes` of extended memory acquired by the given task. + */ + def releaseMemory(numBytes: Long, taskAttemptId: Long): Unit = lock.synchronized { + val curMem = extendedMemoryForTask.getOrElse(taskAttemptId, 0L) + val memoryToFree = if (curMem < numBytes) { + logWarning( + s"Internal error: release called on $numBytes bytes but task only has $curMem bytes " + + s"of memory from the $poolName pool") + curMem + } else { + numBytes + } + if (extendedMemoryForTask.contains(taskAttemptId)) { + extendedMemoryForTask(taskAttemptId) -= memoryToFree + if (extendedMemoryForTask(taskAttemptId) <= 0) { + extendedMemoryForTask.remove(taskAttemptId) + } + } + _memoryUsed -= memoryToFree + lock.notifyAll() // Notify waiters in acquireMemory() that memory has been freed + } + + /** + * Release all memory for the given task and mark it as inactive (e.g. when a task ends). + * + * @return the number of bytes freed. + */ + def releaseAllMemoryForTask(taskAttemptId: Long): Long = lock.synchronized { + val numBytesToFree = getMemoryUsageForTask(taskAttemptId) + releaseMemory(numBytesToFree, taskAttemptId) + numBytesToFree + } +} + diff --git a/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/src/main/scala/org/apache/spark/memory/MemoryManager.scala index 7e8c58cf..7c385cd5 100644 --- a/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -19,6 +19,8 @@ package org.apache.spark.memory import javax.annotation.concurrent.GuardedBy +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -26,7 +28,7 @@ import org.apache.spark.storage.BlockId import org.apache.spark.storage.memory.MemoryStore import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods -import org.apache.spark.unsafe.memory.MemoryAllocator +import org.apache.spark.unsafe.memory.{MemoryAllocator, MemoryBlock} /** * An abstract memory manager that enforces how memory is shared between execution and storage.
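Note (editorial sketch, not part of this patch): the ExtendedMemoryPool added above grants a request either in full or not at all; acquireMemory returns numBytes or 0, never a partial amount. A minimal sketch of how code inside org.apache.spark.memory (for example a TaskMemoryManager code path) might use it; the helper class and method names are hypothetical.

```java
package org.apache.spark.memory;

// Illustration only. Scala's private[memory] members compile to public bytecode,
// which is how the existing Java TaskMemoryManager already calls the execution pools.
final class ExtendedPoolUsageSketch {
  static boolean tryReserveOnPMem(ExtendedMemoryPool pool, long bytes, long taskAttemptId) {
    long granted = pool.acquireMemory(bytes, taskAttemptId);
    return granted == bytes;   // false means 0 bytes were granted: fall back to a disk spill
  }

  static void cleanUp(ExtendedMemoryPool pool, long taskAttemptId) {
    pool.releaseAllMemoryForTask(taskAttemptId);   // frees whatever the task still holds
  }
}
```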
@@ -55,6 +57,8 @@ private[spark] abstract class MemoryManager( protected val onHeapExecutionMemoryPool = new ExecutionMemoryPool(this, MemoryMode.ON_HEAP) @GuardedBy("this") protected val offHeapExecutionMemoryPool = new ExecutionMemoryPool(this, MemoryMode.OFF_HEAP) + @GuardedBy("this") + protected val extendedMemoryPool = new ExtendedMemoryPool(this) onHeapStorageMemoryPool.incrementPoolSize(onHeapStorageMemory) onHeapExecutionMemoryPool.incrementPoolSize(onHeapExecutionMemory) @@ -71,7 +75,14 @@ private[spark] abstract class MemoryManager( protected[this] val pmemStorageMemory = (pmemInitialSize * pmemUsableRatio).toLong pmemStorageMemoryPool.incrementPoolSize(pmemStorageMemory) + protected[this] val extendedMemorySize = conf.get(MEMORY_EXTENDED_SIZE) + extendedMemoryPool.incrementPoolSize((extendedMemorySize * 0.9).toLong) + private[memory] var _pMemPages = new ArrayBuffer[MemoryBlock] + + private[memory] def pMemPages: ArrayBuffer[MemoryBlock] = { + _pMemPages + } /** * Total available on heap memory for storage, in bytes. This amount can vary over time, * depending on the MemoryManager implementation. @@ -86,9 +97,9 @@ private[spark] abstract class MemoryManager( def maxOffHeapStorageMemory: Long /** - * Total available pmem memory for storage, in bytes. This amount can vary over time, - * depending on the MemoryManager implementation. - */ + * Total available pmem memory for storage, in bytes. This amount can vary over time, + * depending on the MemoryManager implementation. + */ def maxPMemStorageMemory: Long /** @@ -134,6 +145,18 @@ private[spark] abstract class MemoryManager( taskAttemptId: Long, memoryMode: MemoryMode): Long + /** + * Try to acquire numBytes of extended memory for the current task and return the number + * of bytes obtained, or 0 if none can be allocated. + * @param numBytes + * @param taskAttemptId + * @return + */ + private[memory] + def acquireExtendedMemory( + numBytes: Long, + taskAttemptId: Long): Long + /** * Release numBytes of execution memory belonging to the given task. */ @@ -185,6 +208,47 @@ private[spark] abstract class MemoryManager( releaseStorageMemory(numBytes, memoryMode) } + /** + * Release extended memory of the given task. + * @param numBytes + * @param taskAttemptId + */ + def releaseExtendedMemory(numBytes: Long, taskAttemptId: Long): Unit = synchronized { + extendedMemoryPool.releaseMemory(numBytes, taskAttemptId) + } + + /** + * Release all extended memory occupied by the given task. + * @param taskAttemptId + * @return + */ + def releaseAllExtendedMemoryForTask(taskAttemptId: Long): Long = synchronized { + extendedMemoryPool.releaseAllMemoryForTask(taskAttemptId) + } + + def addPMemPages(pMemPage: MemoryBlock): Unit = synchronized { + pMemPages.append(pMemPage) + } + + def freeAllPMemPages(): Unit = synchronized { + for (pMemPage <- pMemPages) { + extendedMemoryAllocator.free(pMemPage) + } + } + + /** + * @param size size of the current page request + * @return a PMem page that suits the current page request + */ + def getUsablePMemPage(size: Long): MemoryBlock = synchronized { + for (pMemPage <- pMemPages) { + if (pMemPage.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER && + pMemPage.size() == size) { + return pMemPage + } + } + return null + } /** * Execution memory currently in use, in bytes.
*/ @@ -285,4 +349,6 @@ private[spark] abstract class MemoryManager( case MemoryMode.OFF_HEAP => MemoryAllocator.UNSAFE } } + + private[memory] final val extendedMemoryAllocator = MemoryAllocator.EXTENDED } diff --git a/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 6686082d..8bfcb360 100644 --- a/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -154,6 +154,15 @@ private[spark] class UnifiedMemoryManager( numBytes, taskAttemptId, maybeGrowExecutionPool, () => computeMaxExecutionPoolSize) } + override def acquireExtendedMemory(numBytes: Long, taskAttemptId: Long): Long = synchronized { + if (numBytes > extendedMemoryPool.memoryFree) { + logInfo(s"No PMem Space left, allocation fails.") + return 0; + } + extendedMemoryPool.acquireMemory(numBytes, taskAttemptId); + return numBytes + } + override def acquireStorageMemory( blockId: BlockId, numBytes: Long, diff --git a/src/main/scala/org/apache/spark/storage/BlockManager.scala b/src/main/scala/org/apache/spark/storage/BlockManager.scala index 81e7047d..ee2ef7e7 100644 --- a/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1289,7 +1289,7 @@ private[spark] class BlockManager( syncWrites, writeMetrics, blockId) } - /** + /** * A short circuited method to get a PMem writer that can write data directly to PMem. * The Block will be appended to the PMem stream specified by filename. Callers should handle * error cases. diff --git a/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala b/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala index b9bb4274..2910dbae 100644 --- a/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala +++ b/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala @@ -17,11 +17,12 @@ package org.apache.spark.memory +import com.intel.oap.common.unsafe.PersistentMemoryPlatform import javax.annotation.concurrent.GuardedBy - import scala.collection.mutable import org.apache.spark.SparkConf +import org.apache.spark.internal.config import org.apache.spark.storage.BlockId class TestMemoryManager(conf: SparkConf) @@ -33,11 +34,16 @@ class TestMemoryManager(conf: SparkConf) private var available = Long.MaxValue @GuardedBy("this") private val memoryForTask = mutable.HashMap[Long, Long]().withDefaultValue(0L) + private var extendedMemoryInitialized = false override private[memory] def acquireExecutionMemory( numBytes: Long, taskAttemptId: Long, memoryMode: MemoryMode): Long = synchronized { + if (conf.get(config.MEMORY_SPILL_PMEM_ENABLED) && extendedMemoryInitialized == false) { + PersistentMemoryPlatform.initialize("/dev/shm", 64 * 1024 * 1024, 0) + extendedMemoryInitialized = true + } require(numBytes >= 0) val acquired = { if (consequentOOM > 0) { @@ -75,6 +81,16 @@ class TestMemoryManager(conf: SparkConf) memoryForTask.remove(taskAttemptId).getOrElse(0L) } + override private[memory] def acquireExtendedMemory( + numBytes: Long, + taskAttemptId: Long): Long = synchronized { + if (extendedMemoryInitialized == false) { + PersistentMemoryPlatform.initialize("/dev/shm", 64 * 1024 * 1024, 0) + extendedMemoryInitialized = true + } + return extendedMemoryPool.acquireMemory(numBytes, taskAttemptId) + } + override private[memory] def getExecutionMemoryUsageForTask(taskAttemptId: Long): Long = { memoryForTask.getOrElse(taskAttemptId, 0L) }
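Note (editorial sketch, not part of this patch): pulling the new settings together, an application that wants sorter spills to land on PMem might configure something like the following. Paths and sizes are placeholders, and the writer-type string assumes the enum value names from PMemSpillWriterType in this patch.

```java
import org.apache.spark.SparkConf;

// Illustration only; values are placeholders, not tuning recommendations.
public final class PMemSpillConfSketch {
  public static SparkConf pmemSpillConf() {
    return new SparkConf()
        .set("spark.memory.spill.pmem.enabled", "true")                // route sorter spills to PMem
        .set("spark.memory.extended.path", "/mnt/pmem")                // mount point of the PMem device
        .set("spark.memory.extended.size", "64g")                      // capacity handed to the extended pool
        .set("spark.memory.spill.pmem.config.file", "pmem.properties")
        .set("spark.unsafe.sort.spillwriter.type", "WRITE_SORTED_RECORDS_TO_PMEM")
        .set("spark.memory.spill.pmem.readBufferSize", "8m");          // default per this patch
  }
}
```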