diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index 73577437ac506..5d905943a3aa7 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; /** * Simulates Hive's hashing function from Hive v1.2.1 @@ -38,12 +39,17 @@ public static int hashLong(long input) { return (int) ((input >>> 32) ^ input); } - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { + public static int hashUnsafeBytesBlock(MemoryBlock mb) { + long lengthInBytes = mb.size(); assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int result = 0; - for (int i = 0; i < lengthInBytes; i++) { - result = (result * 31) + (int) Platform.getByte(base, offset + i); + for (long i = 0; i < lengthInBytes; i++) { + result = (result * 31) + (int) mb.getByte(i); } return result; } + + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes) { + return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes)); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index aca6fca00c48b..54dcadf3a7754 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -187,7 +187,7 @@ public static void setMemory(long address, byte value, long size) { } public static void copyMemory( - Object src, long srcOffset, Object dst, long dstOffset, long length) { + Object src, long srcOffset, Object dst, long dstOffset, long length) { // Check if dstOffset is before or after srcOffset to determine if we should copy // forward or backwards. This is necessary in case src and dst overlap. if (dstOffset < srcOffset) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index a6b1f7a16d605..c334c9651cf6b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe.array; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.memory.MemoryBlock; public class ByteArrayMethods { @@ -48,6 +49,16 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { public static int MAX_ROUNDED_ARRAY_LENGTH = Integer.MAX_VALUE - 15; private static final boolean unaligned = Platform.unaligned(); + /** + * MemoryBlock equality check for MemoryBlocks. + * @return true if the arrays are equal, false otherwise + */ + public static boolean arrayEqualsBlock( + MemoryBlock leftBase, long leftOffset, MemoryBlock rightBase, long rightOffset, final long length) { + return arrayEquals(leftBase.getBaseObject(), leftBase.getBaseOffset() + leftOffset, + rightBase.getBaseObject(), rightBase.getBaseOffset() + rightOffset, length); + } + /** * Optimized byte array equality check for byte arrays. 
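The same delegation pattern recurs throughout this patch: the existing (Object base, long offset, length) entry points are kept, but now wrap their arguments in a MemoryBlock and call a new *Block overload. A minimal sketch of the two equivalent HiveHasher call paths (the example class and its inputs are ours; everything it calls comes from this patch):

import org.apache.spark.sql.catalyst.expressions.HiveHasher;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock;
import org.apache.spark.unsafe.memory.MemoryBlock;

public class HiveHasherExample {
  public static void main(String[] args) {
    byte[] data = "hello".getBytes(java.nio.charset.StandardCharsets.UTF_8);
    // Old-style entry point: base object plus absolute offset and length.
    int h1 = HiveHasher.hashUnsafeBytes(data, Platform.BYTE_ARRAY_OFFSET, data.length);
    // New-style entry point: the MemoryBlock carries base, offset and size itself.
    MemoryBlock block = ByteArrayMemoryBlock.fromArray(data);
    int h2 = HiveHasher.hashUnsafeBytesBlock(block);
    System.out.println(h1 == h2);  // true: both paths hash the same bytes
  }
}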
* @return true if the arrays are equal, false otherwise @@ -56,7 +67,7 @@ public static boolean arrayEquals( Object leftBase, long leftOffset, Object rightBase, long rightOffset, final long length) { int i = 0; - // check if stars align and we can get both offsets to be aligned + // check if starts align and we can get both offsets to be aligned if ((leftOffset % 8) == (rightOffset % 8)) { while ((leftOffset + i) % 8 != 0 && i < length) { if (Platform.getByte(leftBase, leftOffset + i) != diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 2cd39bd60c2ac..b74d2de0691d5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -17,7 +17,6 @@ package org.apache.spark.unsafe.array; -import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; /** @@ -33,16 +32,12 @@ public final class LongArray { private static final long WIDTH = 8; private final MemoryBlock memory; - private final Object baseObj; - private final long baseOffset; private final long length; public LongArray(MemoryBlock memory) { assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size >= Integer.MAX_VALUE elements"; this.memory = memory; - this.baseObj = memory.getBaseObject(); - this.baseOffset = memory.getBaseOffset(); this.length = memory.size() / WIDTH; } @@ -51,11 +46,11 @@ public MemoryBlock memoryBlock() { } public Object getBaseObject() { - return baseObj; + return memory.getBaseObject(); } public long getBaseOffset() { - return baseOffset; + return memory.getBaseOffset(); } /** @@ -69,8 +64,8 @@ public long size() { * Fill this all with 0L. */ public void zeroOut() { - for (long off = baseOffset; off < baseOffset + length * WIDTH; off += WIDTH) { - Platform.putLong(baseObj, off, 0); + for (long off = 0; off < length * WIDTH; off += WIDTH) { + memory.putLong(off, 0); } } @@ -80,7 +75,7 @@ public void zeroOut() { public void set(int index, long value) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - Platform.putLong(baseObj, baseOffset + index * WIDTH, value); + memory.putLong(index * WIDTH, value); } /** @@ -89,6 +84,6 @@ public void set(int index, long value) { public long get(int index) { assert index >= 0 : "index (" + index + ") should >= 0"; assert index < length : "index (" + index + ") should < length (" + length + ")"; - return Platform.getLong(baseObj, baseOffset + index * WIDTH); + return memory.getLong(index * WIDTH); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index d239de6083ad0..f372b19fac119 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -17,7 +17,9 @@ package org.apache.spark.unsafe.hash; -import org.apache.spark.unsafe.Platform; +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.memory.MemoryBlock; /** * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. 
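With the cached baseObj/baseOffset fields gone, the LongArray above becomes a thin wrapper over whatever MemoryBlock it is given. A hypothetical usage sketch (the 16-byte block size and the values are ours):

import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.OnHeapMemoryBlock;

public class LongArrayExample {
  public static void main(String[] args) {
    // A 16-byte on-heap block holds two longs; get/set now delegate to
    // MemoryBlock.getLong/putLong instead of raw Platform calls.
    LongArray arr = new LongArray(new OnHeapMemoryBlock(16));
    arr.zeroOut();
    arr.set(0, 42L);
    System.out.println(arr.get(0) + " " + arr.get(1));  // 42 0
  }
}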
@@ -49,49 +51,66 @@ public static int hashInt(int input, int seed) { } public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { - return hashUnsafeWords(base, offset, lengthInBytes, seed); + return hashUnsafeWordsBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); } - public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + public static int hashUnsafeWordsBlock(MemoryBlock base, int seed) { // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; - int h1 = hashBytesByInt(base, offset, lengthInBytes, seed); + int h1 = hashBytesByIntBlock(base, seed); return fmix(h1, lengthInBytes); } - public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { + // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + return hashUnsafeWordsBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + + public static int hashUnsafeBytesBlock(MemoryBlock base, int seed) { // This is not compatible with the original and other implementations. // But we keep it for backward compatibility for components that existed before 2.3. + int lengthInBytes = Ints.checkedCast(base.size()); assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed); for (int i = lengthAligned; i < lengthInBytes; i++) { - int halfWord = Platform.getByte(base, offset + i); + int halfWord = base.getByte(i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } return fmix(h1, lengthInBytes); } + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + public static int hashUnsafeBytes2(Object base, long offset, int lengthInBytes, int seed) { + return hashUnsafeBytes2Block(MemoryBlock.allocateFromObject(base, offset, lengthInBytes), seed); + } + + public static int hashUnsafeBytes2Block(MemoryBlock base, int seed) { // This is compatible with the original and other implementations. // Use this method for new components after Spark 2.3.
- assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthInBytes = Ints.checkedCast(base.size()); + assert (lengthInBytes >= 0) : "lengthInBytes cannot be negative"; int lengthAligned = lengthInBytes - lengthInBytes % 4; - int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + int h1 = hashBytesByIntBlock(base.subBlock(0, lengthAligned), seed); int k1 = 0; for (int i = lengthAligned, shift = 0; i < lengthInBytes; i++, shift += 8) { - k1 ^= (Platform.getByte(base, offset + i) & 0xFF) << shift; + k1 ^= (base.getByte(i) & 0xFF) << shift; } h1 ^= mixK1(k1); return fmix(h1, lengthInBytes); } - private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { + private static int hashBytesByIntBlock(MemoryBlock base, int seed) { + long lengthInBytes = base.size(); assert (lengthInBytes % 4 == 0); int h1 = seed; - for (int i = 0; i < lengthInBytes; i += 4) { - int halfWord = Platform.getInt(base, offset + i); + for (long i = 0; i < lengthInBytes; i += 4) { + int halfWord = base.getInt(i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java new file mode 100644 index 0000000000000..99a9868a49a79 --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/ByteArrayMemoryBlock.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory with a byte array on Java heap. + */ +public final class ByteArrayMemoryBlock extends MemoryBlock { + + private final byte[] array; + + public ByteArrayMemoryBlock(byte[] obj, long offset, long size) { + super(obj, offset, size); + this.array = obj; + assert(offset + size <= Platform.BYTE_ARRAY_OFFSET + obj.length) : + "The sum of size " + size + " and offset " + offset + " should not be larger than " + + "the size of the given memory space " + (obj.length + Platform.BYTE_ARRAY_OFFSET); + } + + public ByteArrayMemoryBlock(long length) { + this(new byte[Ints.checkedCast(length)], Platform.BYTE_ARRAY_OFFSET, length); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new ByteArrayMemoryBlock(array, this.offset + offset, size); + } + + public byte[] getByteArray() { return array; } + + /** + * Creates a memory block pointing to the memory used by the byte array. 
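The Murmur3 refactoring ends here, with hashBytesByIntBlock iterating over the block directly; the ByteArrayMemoryBlock introduced below is the on-heap byte[] flavor it typically runs against. A sketch showing that the old offset-based and new block-based Murmur3 entry points agree (the example class and inputs are ours):

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock;
import org.apache.spark.unsafe.memory.MemoryBlock;

public class Murmur3BlockExample {
  public static void main(String[] args) {
    byte[] bytes = {1, 2, 3, 4, 5, 6, 7, 8};
    int seed = 42;
    // Old-style call: base object, absolute offset, length, seed.
    int viaOffsets =
      Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed);
    // New-style call: the MemoryBlock carries the addressing information.
    MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes);
    int viaBlock = Murmur3_x86_32.hashUnsafeBytesBlock(mb, seed);
    System.out.println(viaOffsets == viaBlock);  // true
    // hashUnsafeWordsBlock additionally requires an 8-byte-aligned size.
    System.out.println(Murmur3_x86_32.hashUnsafeWordsBlock(mb, seed));
  }
}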
+ */ + public static ByteArrayMemoryBlock fromArray(final byte[] array) { + return new ByteArrayMemoryBlock(array, Platform.BYTE_ARRAY_OFFSET, array.length); + } + + @Override + public final int getInt(long offset) { + return Platform.getInt(array, this.offset + offset); + } + + @Override + public final void putInt(long offset, int value) { + Platform.putInt(array, this.offset + offset, value); + } + + @Override + public final boolean getBoolean(long offset) { + return Platform.getBoolean(array, this.offset + offset); + } + + @Override + public final void putBoolean(long offset, boolean value) { + Platform.putBoolean(array, this.offset + offset, value); + } + + @Override + public final byte getByte(long offset) { + return array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)]; + } + + @Override + public final void putByte(long offset, byte value) { + array[(int)(this.offset + offset - Platform.BYTE_ARRAY_OFFSET)] = value; + } + + @Override + public final short getShort(long offset) { + return Platform.getShort(array, this.offset + offset); + } + + @Override + public final void putShort(long offset, short value) { + Platform.putShort(array, this.offset + offset, value); + } + + @Override + public final long getLong(long offset) { + return Platform.getLong(array, this.offset + offset); + } + + @Override + public final void putLong(long offset, long value) { + Platform.putLong(array, this.offset + offset, value); + } + + @Override + public final float getFloat(long offset) { + return Platform.getFloat(array, this.offset + offset); + } + + @Override + public final void putFloat(long offset, float value) { + Platform.putFloat(array, this.offset + offset, value); + } + + @Override + public final double getDouble(long offset) { + return Platform.getDouble(array, this.offset + offset); + } + + @Override + public final void putDouble(long offset, double value) { + Platform.putDouble(array, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 2733760dd19ef..acf28fd7ee59b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -58,7 +58,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { final long[] array = arrayReference.get(); if (array != null) { assert (array.length * 8L >= size); - MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + MemoryBlock memory = OnHeapMemoryBlock.fromArray(array, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -70,7 +70,7 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { } } long[] array = new long[numWords]; - MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + MemoryBlock memory = OnHeapMemoryBlock.fromArray(array, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -79,12 +79,13 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { - assert (memory.obj != null) : + assert(memory instanceof OnHeapMemoryBlock); + assert (memory.getBaseObject() != null) : "baseObject was null; are you trying to use the on-heap allocator to free off-heap memory?"; - assert 
(memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert (memory.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "page has already been freed"; - assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) - || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + assert ((memory.getPageNumber() == MemoryBlock.NO_PAGE_NUMBER) + || (memory.getPageNumber() == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : "TMM-allocated pages must first be freed via TMM.freePage(), not directly in allocator " + "free()"; @@ -94,12 +95,12 @@ public void free(MemoryBlock memory) { } // Mark the page as freed (so we can detect double-frees). - memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; + memory.setPageNumber(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER); // As an additional layer of defense against use-after-free bugs, we mutate the // MemoryBlock to null out its reference to the long[] array. - long[] array = (long[]) memory.obj; - memory.setObjAndOffset(null, 0); + long[] array = ((OnHeapMemoryBlock)memory).getLongArray(); + memory.resetObjAndOffset(); long alignedSize = ((size + 7) / 8) * 8; if (shouldPool(alignedSize)) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java index 7b588681d9790..38315fb97b46a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java @@ -38,7 +38,7 @@ public interface MemoryAllocator { void free(MemoryBlock memory); - MemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); + UnsafeMemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); - MemoryAllocator HEAP = new HeapMemoryAllocator(); + HeapMemoryAllocator HEAP = new HeapMemoryAllocator(); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index c333857358d30..b086941108522 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -22,10 +22,10 @@ import org.apache.spark.unsafe.Platform; /** - * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. + * A representation of a consecutive memory block in Spark. It defines the common interfaces + * for accessing and mutating memory. */ -public class MemoryBlock extends MemoryLocation { - +public abstract class MemoryBlock { /** Special `pageNumber` value for pages which were not allocated by TaskMemoryManagers */ public static final int NO_PAGE_NUMBER = -1; @@ -45,38 +45,163 @@ public class MemoryBlock extends MemoryLocation { */ public static final int FREED_IN_ALLOCATOR_PAGE_NUMBER = -3; - private final long length; + @Nullable + protected Object obj; + + protected long offset; + + protected long length; /** * Optional page number; used when this MemoryBlock represents a page allocated by a - * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, - * which lives in a different package. + * TaskMemoryManager. This field can be updated through the setPageNumber method, so that + * it can still be modified by the TaskMemoryManager, which lives in a different package. 
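Since allocation and page bookkeeping now go through accessors, the allocate/free life cycle of an on-heap block looks roughly like this (a hypothetical round trip; the class name and 64-byte size are ours):

import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.MemoryBlock;

public class HeapAllocatorExample {
  public static void main(String[] args) {
    // HEAP is now declared as HeapMemoryAllocator rather than plain MemoryAllocator.
    MemoryBlock block = MemoryAllocator.HEAP.allocate(64);
    block.fill((byte) 0);
    System.out.println(block.getPageNumber());  // -1 (NO_PAGE_NUMBER): not a TMM page
    MemoryAllocator.HEAP.free(block);
    // free() nulls the base object and marks the page as freed-in-allocator.
    System.out.println(block.getBaseObject() == null);  // true
  }
}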
*/ - public int pageNumber = NO_PAGE_NUMBER; - public MemoryBlock(@Nullable Object obj, long offset, long length) { - super(obj, offset); + private int pageNumber = NO_PAGE_NUMBER; + protected MemoryBlock(@Nullable Object obj, long offset, long length) { + if (offset < 0 || length < 0) { + throw new IllegalArgumentException( + "Length " + length + " and offset " + offset + " must be non-negative"); + } + this.obj = obj; + this.offset = offset; this.length = length; } + protected MemoryBlock() { + this(null, 0, 0); + } + + public final Object getBaseObject() { + return obj; + } + + public final long getBaseOffset() { + return offset; + } + + public void resetObjAndOffset() { + this.obj = null; + this.offset = 0; + } + /** * Returns the size of the memory block. */ - public long size() { + public final long size() { return length; } - /** - * Creates a memory block pointing to the memory used by the long array. - */ - public static MemoryBlock fromLongArray(final long[] array) { - return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); + public final void setPageNumber(int pageNum) { + pageNumber = pageNum; + } + + public final int getPageNumber() { + return pageNumber; } /** * Fills the memory block with the specified byte value. */ - public void fill(byte value) { + public final void fill(byte value) { Platform.setMemory(obj, offset, length, value); } + + /** + * Instantiates a MemoryBlock of the concrete type matching the given base object, using the + * given offset and length. + */ + public static final MemoryBlock allocateFromObject(Object obj, long offset, long length) { + MemoryBlock mb = null; + if (obj instanceof byte[]) { + byte[] array = (byte[])obj; + mb = new ByteArrayMemoryBlock(array, offset, length); + } else if (obj instanceof long[]) { + long[] array = (long[])obj; + mb = new OnHeapMemoryBlock(array, offset, length); + } else if (obj == null) { + // a null base object is assumed to mean off-heap memory + mb = new OffHeapMemoryBlock(offset, length); + } else { + throw new UnsupportedOperationException( + "Instantiating a MemoryBlock for type " + obj.getClass() + " is not supported"); + } + return mb; + } + + /** + * Instantiates a sub-block of the same MemoryBlock type, with the given size and an offset + * relative to the original offset. The data is not copied. + * If the parameters are invalid, an exception is thrown. + */ + public abstract MemoryBlock subBlock(long offset, long size); + + protected void checkSubBlockRange(long offset, long size) { + if (offset < 0 || size < 0) { + throw new ArrayIndexOutOfBoundsException( + "Size " + size + " and offset " + offset + " must be non-negative"); + } + if (offset + size > length) { + throw new ArrayIndexOutOfBoundsException("The sum of size " + size + " and offset " + + offset + " should not be larger than the length " + length + " in the MemoryBlock"); + } + } + + /** + * The behavior of getXXX/putXXX is not guaranteed if the offset is invalid: it may cause an + * illegal memory access, throw an exception, etc. + * getXXX/putXXX addresses memory relative to this.offset, which already accounts for metadata + * such as the JVM object header. The offset parameter is 0-based and is interpreted as a + * logical offset into the memory block. 
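To make the logical-offset contract above concrete, a small sketch of the factory-plus-accessor API (variable names and sizes are ours):

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.MemoryBlock;

public class MemoryBlockApiExample {
  public static void main(String[] args) {
    long[] backing = new long[4];  // 32 bytes on the Java heap
    // allocateFromObject dispatches on the base type: long[] -> OnHeapMemoryBlock.
    MemoryBlock block = MemoryBlock.allocateFromObject(backing, Platform.LONG_ARRAY_OFFSET, 32);
    block.putLong(0, 0x1234L);                 // offsets are logical and 0-based
    block.putInt(8, 42);
    MemoryBlock slice = block.subBlock(8, 8);  // same backing array, no copy
    System.out.println(slice.getInt(0));       // 42
  }
}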
*/ + public abstract int getInt(long offset); + + public abstract void putInt(long offset, int value); + + public abstract boolean getBoolean(long offset); + + public abstract void putBoolean(long offset, boolean value); + + public abstract byte getByte(long offset); + + public abstract void putByte(long offset, byte value); + + public abstract short getShort(long offset); + + public abstract void putShort(long offset, short value); + + public abstract long getLong(long offset); + + public abstract void putLong(long offset, long value); + + public abstract float getFloat(long offset); + + public abstract void putFloat(long offset, float value); + + public abstract double getDouble(long offset); + + public abstract void putDouble(long offset, double value); + + public static final void copyMemory( + MemoryBlock src, long srcOffset, MemoryBlock dst, long dstOffset, long length) { + assert(srcOffset + length <= src.length && dstOffset + length <= dst.length); + Platform.copyMemory(src.getBaseObject(), src.getBaseOffset() + srcOffset, + dst.getBaseObject(), dst.getBaseOffset() + dstOffset, length); + } + + public static final void copyMemory(MemoryBlock src, MemoryBlock dst, long length) { + assert(length <= src.length && length <= dst.length); + Platform.copyMemory(src.getBaseObject(), src.getBaseOffset(), + dst.getBaseObject(), dst.getBaseOffset(), length); + } + + public final void copyFrom(Object src, long srcOffset, long dstOffset, long length) { + assert(length <= this.length - dstOffset); + Platform.copyMemory(src, srcOffset, obj, offset + dstOffset, length); + } + + public final void writeTo(long srcOffset, Object dst, long dstOffset, long length) { + assert(length <= this.length - srcOffset); + Platform.copyMemory(obj, offset + srcOffset, dst, dstOffset, length); + } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java deleted file mode 100644 index 74ebc87dc978c..0000000000000 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import javax.annotation.Nullable; - -/** - * A memory location. Tracked either by a memory address (with off-heap allocation), - * or by an offset from a JVM object (in-heap allocation). 
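Note that the two instance-level copy helpers mix addressing conventions: copyFrom/writeTo take a logical offset for this block but an absolute, Platform-style offset for the other object. A sketch (sizes and values are ours):

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.OnHeapMemoryBlock;

public class BlockCopyExample {
  public static void main(String[] args) {
    OnHeapMemoryBlock block = new OnHeapMemoryBlock(16);
    long[] src = {7L, 8L};
    // copyFrom(source object, absolute source offset, logical dest offset, bytes)
    block.copyFrom(src, Platform.LONG_ARRAY_OFFSET, 0, 16);
    long[] dst = new long[2];
    // writeTo(logical source offset, dest object, absolute dest offset, bytes)
    block.writeTo(0, dst, Platform.LONG_ARRAY_OFFSET, 16);
    System.out.println(dst[0] + " " + dst[1]);  // 7 8
  }
}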
- */ -public class MemoryLocation { - - @Nullable - Object obj; - - long offset; - - public MemoryLocation(@Nullable Object obj, long offset) { - this.obj = obj; - this.offset = offset; - } - - public MemoryLocation() { - this(null, 0); - } - - public void setObjAndOffset(Object newObj, long newOffset) { - this.obj = newObj; - this.offset = newOffset; - } - - public final Object getBaseObject() { - return obj; - } - - public final long getBaseOffset() { - return offset; - } -} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java new file mode 100644 index 0000000000000..f90f62bf21dcb --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OffHeapMemoryBlock.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.Platform; + +public class OffHeapMemoryBlock extends MemoryBlock { + static public final OffHeapMemoryBlock NULL = new OffHeapMemoryBlock(0, 0); + + public OffHeapMemoryBlock(long address, long size) { + super(null, address, size); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new OffHeapMemoryBlock(this.offset + offset, size); + } + + @Override + public final int getInt(long offset) { + return Platform.getInt(null, this.offset + offset); + } + + @Override + public final void putInt(long offset, int value) { + Platform.putInt(null, this.offset + offset, value); + } + + @Override + public final boolean getBoolean(long offset) { + return Platform.getBoolean(null, this.offset + offset); + } + + @Override + public final void putBoolean(long offset, boolean value) { + Platform.putBoolean(null, this.offset + offset, value); + } + + @Override + public final byte getByte(long offset) { + return Platform.getByte(null, this.offset + offset); + } + + @Override + public final void putByte(long offset, byte value) { + Platform.putByte(null, this.offset + offset, value); + } + + @Override + public final short getShort(long offset) { + return Platform.getShort(null, this.offset + offset); + } + + @Override + public final void putShort(long offset, short value) { + Platform.putShort(null, this.offset + offset, value); + } + + @Override + public final long getLong(long offset) { + return Platform.getLong(null, this.offset + offset); + } + + @Override + public final void putLong(long offset, long value) { + Platform.putLong(null, this.offset + offset, value); + } + + @Override + public final float getFloat(long offset) { + return Platform.getFloat(null, this.offset + offset); + } + + 
@Override + public final void putFloat(long offset, float value) { + Platform.putFloat(null, this.offset + offset, value); + } + + @Override + public final double getDouble(long offset) { + return Platform.getDouble(null, this.offset + offset); + } + + @Override + public final void putDouble(long offset, double value) { + Platform.putDouble(null, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java new file mode 100644 index 0000000000000..12f67c7bd593e --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/OnHeapMemoryBlock.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import com.google.common.primitives.Ints; + +import org.apache.spark.unsafe.Platform; + +/** + * A consecutive block of memory with a long array on Java heap. + */ +public final class OnHeapMemoryBlock extends MemoryBlock { + + private final long[] array; + + public OnHeapMemoryBlock(long[] obj, long offset, long size) { + super(obj, offset, size); + this.array = obj; + assert(offset + size <= obj.length * 8L + Platform.LONG_ARRAY_OFFSET) : + "The sum of size " + size + " and offset " + offset + " should not be larger than " + + "the size of the given memory space " + (obj.length * 8L + Platform.LONG_ARRAY_OFFSET); + } + + public OnHeapMemoryBlock(long size) { + this(new long[Ints.checkedCast((size + 7) / 8)], Platform.LONG_ARRAY_OFFSET, size); + } + + @Override + public MemoryBlock subBlock(long offset, long size) { + checkSubBlockRange(offset, size); + if (offset == 0 && size == this.size()) return this; + return new OnHeapMemoryBlock(array, this.offset + offset, size); + } + + public long[] getLongArray() { return array; } + + /** + * Creates a memory block pointing to the memory used by the long array. 
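OffHeapMemoryBlock routes every accessor through Platform with a null base object; combined with the covariant return type that UnsafeMemoryAllocator.allocate() gains below, off-heap usage looks like this (a sketch; the class name and 32-byte size are ours):

import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.OffHeapMemoryBlock;

public class OffHeapExample {
  public static void main(String[] args) {
    // UNSAFE is now typed as UnsafeMemoryAllocator, whose allocate() returns
    // OffHeapMemoryBlock directly, so no cast is needed.
    OffHeapMemoryBlock block = MemoryAllocator.UNSAFE.allocate(32);
    block.putDouble(0, 3.14);  // logical offset; the base offset is the raw address
    System.out.println(block.getDouble(0));
    MemoryAllocator.UNSAFE.free(block);
  }
}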
+ */ + public static OnHeapMemoryBlock fromArray(final long[] array) { + return new OnHeapMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); + } + + public static OnHeapMemoryBlock fromArray(final long[] array, long size) { + return new OnHeapMemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); + } + + @Override + public final int getInt(long offset) { + return Platform.getInt(array, this.offset + offset); + } + + @Override + public final void putInt(long offset, int value) { + Platform.putInt(array, this.offset + offset, value); + } + + @Override + public final boolean getBoolean(long offset) { + return Platform.getBoolean(array, this.offset + offset); + } + + @Override + public final void putBoolean(long offset, boolean value) { + Platform.putBoolean(array, this.offset + offset, value); + } + + @Override + public final byte getByte(long offset) { + return Platform.getByte(array, this.offset + offset); + } + + @Override + public final void putByte(long offset, byte value) { + Platform.putByte(array, this.offset + offset, value); + } + + @Override + public final short getShort(long offset) { + return Platform.getShort(array, this.offset + offset); + } + + @Override + public final void putShort(long offset, short value) { + Platform.putShort(array, this.offset + offset, value); + } + + @Override + public final long getLong(long offset) { + return Platform.getLong(array, this.offset + offset); + } + + @Override + public final void putLong(long offset, long value) { + Platform.putLong(array, this.offset + offset, value); + } + + @Override + public final float getFloat(long offset) { + return Platform.getFloat(array, this.offset + offset); + } + + @Override + public final void putFloat(long offset, float value) { + Platform.putFloat(array, this.offset + offset, value); + } + + @Override + public final double getDouble(long offset) { + return Platform.getDouble(array, this.offset + offset); + } + + @Override + public final void putDouble(long offset, double value) { + Platform.putDouble(array, this.offset + offset, value); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java index 4368fb615ba1e..5310bdf2779a9 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -25,9 +25,9 @@ public class UnsafeMemoryAllocator implements MemoryAllocator { @Override - public MemoryBlock allocate(long size) throws OutOfMemoryError { + public OffHeapMemoryBlock allocate(long size) throws OutOfMemoryError { long address = Platform.allocateMemory(size); - MemoryBlock memory = new MemoryBlock(null, address, size); + OffHeapMemoryBlock memory = new OffHeapMemoryBlock(address, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); } @@ -36,22 +36,25 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { @Override public void free(MemoryBlock memory) { - assert (memory.obj == null) : - "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; - assert (memory.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : + assert(memory instanceof OffHeapMemoryBlock) : + "UnsafeMemoryAllocator can only free OffHeapMemoryBlock."; + if (memory == OffHeapMemoryBlock.NULL) return; + assert (memory.getPageNumber() != 
MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) : "page has already been freed"; - assert ((memory.pageNumber == MemoryBlock.NO_PAGE_NUMBER) - || (memory.pageNumber == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : + assert ((memory.getPageNumber() == MemoryBlock.NO_PAGE_NUMBER) + || (memory.getPageNumber() == MemoryBlock.FREED_IN_TMM_PAGE_NUMBER)) : "TMM-allocated pages must be freed via TMM.freePage(), not directly in allocator free()"; if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); } + Platform.freeMemory(memory.offset); + // As an additional layer of defense against use-after-free bugs, we mutate the // MemoryBlock to reset its pointer. - memory.offset = 0; + memory.resetObjAndOffset(); // Mark the page as freed (so we can detect double-frees). - memory.pageNumber = MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER; + memory.setPageNumber(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index b0d0c44823e68..302be6407d5db 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -30,9 +30,12 @@ import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import com.google.common.primitives.Ints; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock; +import org.apache.spark.unsafe.memory.MemoryBlock; import static org.apache.spark.unsafe.Platform.*; @@ -50,12 +53,13 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, // These are only updated by readExternal() or read() @Nonnull - private Object base; - private long offset; + private MemoryBlock base; + // While numBytes has the same value as base.size(), keeping it as an int avoids casts from long to int private int numBytes; - public Object getBaseObject() { return base; } - public long getBaseOffset() { return offset; } + public MemoryBlock getMemoryBlock() { return base; } + public Object getBaseObject() { return base.getBaseObject(); } + public long getBaseOffset() { return base.getBaseOffset(); } private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, @@ -77,7 +81,8 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, */ public static UTF8String fromBytes(byte[] bytes) { if (bytes != null) { - return new UTF8String(bytes, BYTE_ARRAY_OFFSET, bytes.length); + return new UTF8String( + new ByteArrayMemoryBlock(bytes, BYTE_ARRAY_OFFSET, bytes.length)); } else { return null; } @@ -90,19 +95,13 @@ public static UTF8String fromBytes(byte[] bytes) { */ public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) { if (bytes != null) { - return new UTF8String(bytes, BYTE_ARRAY_OFFSET + offset, numBytes); + return new UTF8String( + new ByteArrayMemoryBlock(bytes, BYTE_ARRAY_OFFSET + offset, numBytes)); } else { return null; } } - /** - * Creates an UTF8String from given address (base and offset) and length. - */ - public static UTF8String fromAddress(Object base, long offset, int numBytes) { - return new UTF8String(base, offset, numBytes); - } - /** * Creates an UTF8String from String. 
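UTF8String now stores a MemoryBlock instead of a base/offset pair, and the removed fromAddress factory is replaced by the now-public MemoryBlock constructor. A sketch of two ways to build equal strings under the new scheme (the class name and literals are ours):

import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock;
import org.apache.spark.unsafe.types.UTF8String;

public class UTF8StringBlockExample {
  public static void main(String[] args) {
    byte[] bytes = "Spark".getBytes(java.nio.charset.StandardCharsets.UTF_8);
    // fromBytes(...) wraps the array in a ByteArrayMemoryBlock internally.
    UTF8String s1 = UTF8String.fromBytes(bytes);
    // The constructor takes any MemoryBlock, e.g. a subBlock, as the updated tests do.
    UTF8String s2 =
      new UTF8String(ByteArrayMemoryBlock.fromArray(bytes).subBlock(0, bytes.length));
    System.out.println(s1.equals(s2));  // true: same underlying bytes
  }
}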
*/ @@ -119,16 +118,13 @@ public static UTF8String blankString(int length) { return fromBytes(spaces); } - protected UTF8String(Object base, long offset, int numBytes) { + public UTF8String(MemoryBlock base) { this.base = base; - this.offset = offset; - this.numBytes = numBytes; + this.numBytes = Ints.checkedCast(base.size()); } // for serialization - public UTF8String() { - this(null, 0, 0); - } + public UTF8String() {} /** * Writes the content of this string into a memory address, identified by an object and an offset. @@ -136,7 +132,7 @@ public UTF8String() { * bytes in this string. */ public void writeToMemory(Object target, long targetOffset) { - Platform.copyMemory(base, offset, target, targetOffset, numBytes); + base.writeTo(0, target, targetOffset, numBytes); } public void writeTo(ByteBuffer buffer) { @@ -156,8 +152,9 @@ public void writeTo(ByteBuffer buffer) { */ @Nonnull public ByteBuffer getByteBuffer() { - if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) { - final byte[] bytes = (byte[]) base; + long offset = base.getBaseOffset(); + if (base instanceof ByteArrayMemoryBlock && offset >= BYTE_ARRAY_OFFSET) { + final byte[] bytes = ((ByteArrayMemoryBlock) base).getByteArray(); // the offset includes an object header... this is only needed for unsafe copies final long arrayOffset = offset - BYTE_ARRAY_OFFSET; @@ -223,12 +220,12 @@ public long getPrefix() { long mask = 0; if (IS_LITTLE_ENDIAN) { if (numBytes >= 8) { - p = Platform.getLong(base, offset); + p = base.getLong(0); } else if (numBytes > 4) { - p = Platform.getLong(base, offset); + p = base.getLong(0); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = (long) Platform.getInt(base, offset); + p = (long) base.getInt(0); mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -237,12 +234,12 @@ public long getPrefix() { } else { // byteOrder == ByteOrder.BIG_ENDIAN if (numBytes >= 8) { - p = Platform.getLong(base, offset); + p = base.getLong(0); } else if (numBytes > 4) { - p = Platform.getLong(base, offset); + p = base.getLong(0); mask = (1L << (8 - numBytes) * 8) - 1; } else if (numBytes > 0) { - p = ((long) Platform.getInt(base, offset)) << 32; + p = ((long) base.getInt(0)) << 32; mask = (1L << (8 - numBytes) * 8) - 1; } else { p = 0; @@ -257,12 +254,13 @@ public long getPrefix() { */ public byte[] getBytes() { // avoid copy if `base` is `byte[]` - if (offset == BYTE_ARRAY_OFFSET && base instanceof byte[] - && ((byte[]) base).length == numBytes) { - return (byte[]) base; + long offset = base.getBaseOffset(); + if (offset == BYTE_ARRAY_OFFSET && base instanceof ByteArrayMemoryBlock + && (((ByteArrayMemoryBlock) base).getByteArray()).length == numBytes) { + return ((ByteArrayMemoryBlock) base).getByteArray(); } else { byte[] bytes = new byte[numBytes]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, numBytes); return bytes; } } @@ -292,7 +290,7 @@ public UTF8String substring(final int start, final int until) { if (i > j) { byte[] bytes = new byte[i - j]; - copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); + base.writeTo(j, bytes, BYTE_ARRAY_OFFSET, i - j); return fromBytes(bytes); } else { return EMPTY_UTF8; @@ -333,14 +331,14 @@ public boolean contains(final UTF8String substring) { * Returns the byte at position `i`. 
*/ private byte getByte(int i) { - return Platform.getByte(base, offset + i); + return base.getByte(i); } private boolean matchAt(final UTF8String s, int pos) { if (s.numBytes + pos > numBytes || pos < 0) { return false; } - return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes); + return ByteArrayMethods.arrayEqualsBlock(base, pos, s.base, 0, s.numBytes); } public boolean startsWith(final UTF8String prefix) { @@ -467,8 +465,7 @@ public int findInSet(UTF8String match) { for (int i = 0; i < numBytes; i++) { if (getByte(i) == (byte) ',') { if (i - (lastComma + 1) == match.numBytes && - ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, - match.numBytes)) { + ByteArrayMethods.arrayEqualsBlock(base, lastComma + 1, match.base, 0, match.numBytes)) { return n; } lastComma = i; @@ -476,8 +473,7 @@ public int findInSet(UTF8String match) { } } if (numBytes - (lastComma + 1) == match.numBytes && - ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset, - match.numBytes)) { + ByteArrayMethods.arrayEqualsBlock(base, lastComma + 1, match.base, 0, match.numBytes)) { return n; } return 0; @@ -492,7 +488,7 @@ public int findInSet(UTF8String match) { private UTF8String copyUTF8String(int start, int end) { int len = end - start + 1; byte[] newBytes = new byte[len]; - copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); + base.writeTo(start, newBytes, BYTE_ARRAY_OFFSET, len); return UTF8String.fromBytes(newBytes); } @@ -639,8 +635,7 @@ public UTF8String reverse() { int i = 0; // position in byte while (i < numBytes) { int len = numBytesForFirstByte(getByte(i)); - copyMemory(this.base, this.offset + i, result, - BYTE_ARRAY_OFFSET + result.length - i - len, len); + base.writeTo(i, result, BYTE_ARRAY_OFFSET + result.length - i - len, len); i += len; } @@ -654,7 +649,7 @@ public UTF8String repeat(int times) { } byte[] newBytes = new byte[numBytes * times]; - copyMemory(this.base, this.offset, newBytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, newBytes, BYTE_ARRAY_OFFSET, numBytes); int copied = 1; while (copied < times) { @@ -691,7 +686,7 @@ public int indexOf(UTF8String v, int start) { if (i + v.numBytes > numBytes) { return -1; } - if (ByteArrayMethods.arrayEquals(base, offset + i, v.base, v.offset, v.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, i, v.base, 0, v.numBytes)) { return c; } i += numBytesForFirstByte(getByte(i)); @@ -707,7 +702,7 @@ public int indexOf(UTF8String v, int start) { private int find(UTF8String str, int start) { assert (str.numBytes > 0); while (start <= numBytes - str.numBytes) { - if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, start, str.base, 0, str.numBytes)) { return start; } start += 1; @@ -721,7 +716,7 @@ private int find(UTF8String str, int start) { private int rfind(UTF8String str, int start) { assert (str.numBytes > 0); while (start >= 0) { - if (ByteArrayMethods.arrayEquals(base, offset + start, str.base, str.offset, str.numBytes)) { + if (ByteArrayMethods.arrayEqualsBlock(base, start, str.base, 0, str.numBytes)) { return start; } start -= 1; @@ -754,7 +749,7 @@ public UTF8String subStringIndex(UTF8String delim, int count) { return EMPTY_UTF8; } byte[] bytes = new byte[idx]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, idx); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, idx); return fromBytes(bytes); } else { @@ -774,7 +769,7 @@ public UTF8String 
subStringIndex(UTF8String delim, int count) { } int size = numBytes - delim.numBytes - idx; byte[] bytes = new byte[size]; - copyMemory(base, offset + idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); + base.writeTo(idx + delim.numBytes, bytes, BYTE_ARRAY_OFFSET, size); return fromBytes(bytes); } } @@ -797,15 +792,15 @@ public UTF8String rpad(int len, UTF8String pad) { UTF8String remain = pad.substring(0, spaces - padChars * count); byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; - copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET, this.numBytes); + base.writeTo(0, data, BYTE_ARRAY_OFFSET, this.numBytes); int offset = this.numBytes; int idx = 0; while (idx < count) { - copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + pad.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++ idx; offset += pad.numBytes; } - copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + remain.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); return UTF8String.fromBytes(data); } @@ -833,13 +828,13 @@ public UTF8String lpad(int len, UTF8String pad) { int offset = 0; int idx = 0; while (idx < count) { - copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); + pad.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++ idx; offset += pad.numBytes; } - copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); + remain.base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); offset += remain.numBytes; - copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET + offset, numBytes()); + base.writeTo(0, data, BYTE_ARRAY_OFFSET + offset, numBytes()); return UTF8String.fromBytes(data); } @@ -864,8 +859,8 @@ public static UTF8String concat(UTF8String... inputs) { int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; - copyMemory( - inputs[i].base, inputs[i].offset, + inputs[i].base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -904,8 +899,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { for (int i = 0, j = 0; i < inputs.length; i++) { if (inputs[i] != null) { int len = inputs[i].numBytes; - copyMemory( - inputs[i].base, inputs[i].offset, + inputs[i].base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, len); offset += len; @@ -913,8 +908,8 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { j++; // Add separator if this is not the last input. 
if (j < numInputs) { - copyMemory( - separator.base, separator.offset, + separator.base.writeTo( + 0, result, BYTE_ARRAY_OFFSET + offset, separator.numBytes); offset += separator.numBytes; @@ -1188,7 +1183,7 @@ public UTF8String clone() { public UTF8String copy() { byte[] bytes = new byte[numBytes]; - copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + base.writeTo(0, bytes, BYTE_ARRAY_OFFSET, numBytes); return fromBytes(bytes); } @@ -1196,11 +1191,10 @@ public UTF8String copy() { public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); int wordMax = (len / 8) * 8; - long roffset = other.offset; - Object rbase = other.base; + MemoryBlock rbase = other.base; for (int i = 0; i < wordMax; i += 8) { - long left = getLong(base, offset + i); - long right = getLong(rbase, roffset + i); + long left = base.getLong(i); + long right = rbase.getLong(i); if (left != right) { if (IS_LITTLE_ENDIAN) { return Long.compareUnsigned(Long.reverseBytes(left), Long.reverseBytes(right)); @@ -1211,7 +1205,7 @@ public int compareTo(@Nonnull final UTF8String other) { } for (int i = wordMax; i < len; i++) { // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. - int res = (getByte(i) & 0xFF) - (Platform.getByte(rbase, roffset + i) & 0xFF); + int res = (getByte(i) & 0xFF) - (rbase.getByte(i) & 0xFF); if (res != 0) { return res; } @@ -1230,7 +1224,7 @@ public boolean equals(final Object other) { if (numBytes != o.numBytes) { return false; } - return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes); + return ByteArrayMethods.arrayEqualsBlock(base, 0, o.base, 0, numBytes); } else { return false; } @@ -1286,8 +1280,8 @@ public int levenshteinDistance(UTF8String other) { num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) { cost = 1; } else { - cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base, - s.offset + i_bytes, num_bytes_j)) ? 0 : 1; + cost = (ByteArrayMethods.arrayEqualsBlock(t.base, j_bytes, s.base, + i_bytes, num_bytes_j)) ? 
0 : 1; } d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost); } @@ -1302,7 +1296,7 @@ public int levenshteinDistance(UTF8String other) { @Override public int hashCode() { - return Murmur3_x86_32.hashUnsafeBytes(base, offset, numBytes, 42); + return Murmur3_x86_32.hashUnsafeBytesBlock(base, 42); } /** @@ -1365,10 +1359,10 @@ public void writeExternal(ObjectOutput out) throws IOException { } public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - offset = BYTE_ARRAY_OFFSET; numBytes = in.readInt(); - base = new byte[numBytes]; - in.readFully((byte[]) base); + byte[] bytes = new byte[numBytes]; + in.readFully(bytes); + base = ByteArrayMemoryBlock.fromArray(bytes); } @Override @@ -1380,10 +1374,10 @@ public void write(Kryo kryo, Output out) { @Override public void read(Kryo kryo, Input in) { - this.offset = BYTE_ARRAY_OFFSET; - this.numBytes = in.readInt(); - this.base = new byte[numBytes]; - in.read((byte[]) base); + numBytes = in.readInt(); + byte[] bytes = new byte[numBytes]; + in.read(bytes); + base = ByteArrayMemoryBlock.fromArray(bytes); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 3ad9ac7b4de9c..583a148b3845d 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -81,7 +81,7 @@ public void freeingOnHeapMemoryBlockResetsBaseObjectAndOffset() { MemoryAllocator.HEAP.free(block); Assert.assertNull(block.getBaseObject()); Assert.assertEquals(0, block.getBaseOffset()); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.getPageNumber()); } @Test @@ -92,7 +92,7 @@ public void freeingOffHeapMemoryBlockResetsOffset() { MemoryAllocator.UNSAFE.free(block); Assert.assertNull(block.getBaseObject()); Assert.assertEquals(0, block.getBaseOffset()); - Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.pageNumber); + Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, block.getPageNumber()); } @Test(expected = AssertionError.class) diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java index fb8e53b3348f3..8c2e98c2bfc54 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -20,14 +20,13 @@ import org.junit.Assert; import org.junit.Test; -import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.OnHeapMemoryBlock; public class LongArraySuite { @Test public void basicTest() { - long[] bytes = new long[2]; - LongArray arr = new LongArray(MemoryBlock.fromLongArray(bytes)); + LongArray arr = new LongArray(new OnHeapMemoryBlock(16)); arr.set(0, 1L); arr.set(1, 2L); arr.set(1, 3L); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java index 6348a73bf3895..d7ed005db1891 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -70,6 +70,24 @@ public void testKnownBytesInputs() 
{ Murmur3_x86_32.hashUnsafeBytes2(tes, Platform.BYTE_ARRAY_OFFSET, tes.length, 0)); } + @Test + public void testKnownWordsInputs() { + byte[] bytes = new byte[16]; + long offset = Platform.BYTE_ARRAY_OFFSET; + for (int i = 0; i < 16; i++) { + bytes[i] = 0; + } + Assert.assertEquals(-300363099, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + for (int i = 0; i < 16; i++) { + bytes[i] = -1; + } + Assert.assertEquals(-1210324667, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + for (int i = 0; i < 16; i++) { + bytes[i] = (byte)i; + } + Assert.assertEquals(-634919701, hasher.hashUnsafeWords(bytes, offset, 16, 42)); + } + @Test public void randomizedStressTest() { int size = 65536; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java new file mode 100644 index 0000000000000..47f05c928f2e5 --- /dev/null +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/memory/MemoryBlockSuite.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.unsafe.memory;
+
+import org.apache.spark.unsafe.Platform;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.ByteOrder;
+
+import static org.hamcrest.core.StringContains.containsString;
+
+public class MemoryBlockSuite {
+  private static final boolean bigEndianPlatform =
+    ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN);
+
+  private void check(MemoryBlock memory, Object obj, long offset, int length) {
+    memory.setPageNumber(1);
+    memory.fill((byte)-1);
+    memory.putBoolean(0, true);
+    memory.putByte(1, (byte)127);
+    memory.putShort(2, (short)257);
+    memory.putInt(4, 0x20000002);
+    memory.putLong(8, 0x1234567089ABCDEFL);
+    memory.putFloat(16, 1.0F);
+    memory.putLong(20, 0x1234567089ABCDEFL);
+    memory.putDouble(28, 2.0);
+    MemoryBlock.copyMemory(memory, 0L, memory, 36, 4);
+    int[] a = new int[2];
+    a[0] = 0x12345678;
+    a[1] = 0x13579BDF;
+    memory.copyFrom(a, Platform.INT_ARRAY_OFFSET, 40, 8);
+    byte[] b = new byte[8];
+    memory.writeTo(40, b, Platform.BYTE_ARRAY_OFFSET, 8);
+
+    Assert.assertEquals(obj, memory.getBaseObject());
+    Assert.assertEquals(offset, memory.getBaseOffset());
+    Assert.assertEquals(length, memory.size());
+    Assert.assertEquals(1, memory.getPageNumber());
+    Assert.assertEquals(true, memory.getBoolean(0));
+    Assert.assertEquals((byte)127, memory.getByte(1));
+    Assert.assertEquals((short)257, memory.getShort(2));
+    Assert.assertEquals(0x20000002, memory.getInt(4));
+    Assert.assertEquals(0x1234567089ABCDEFL, memory.getLong(8));
+    Assert.assertEquals(1.0F, memory.getFloat(16), 0);
+    Assert.assertEquals(0x1234567089ABCDEFL, memory.getLong(20));
+    Assert.assertEquals(2.0, memory.getDouble(28), 0);
+    Assert.assertEquals(true, memory.getBoolean(36));
+    Assert.assertEquals((byte)127, memory.getByte(37));
+    Assert.assertEquals((short)257, memory.getShort(38));
+    Assert.assertEquals(a[0], memory.getInt(40));
+    Assert.assertEquals(a[1], memory.getInt(44));
+    if (bigEndianPlatform) {
+      Assert.assertEquals(a[0],
+        ((int)b[0] & 0xff) << 24 | ((int)b[1] & 0xff) << 16 |
+        ((int)b[2] & 0xff) << 8 | ((int)b[3] & 0xff));
+      Assert.assertEquals(a[1],
+        ((int)b[4] & 0xff) << 24 | ((int)b[5] & 0xff) << 16 |
+        ((int)b[6] & 0xff) << 8 | ((int)b[7] & 0xff));
+    } else {
+      Assert.assertEquals(a[0],
+        ((int)b[3] & 0xff) << 24 | ((int)b[2] & 0xff) << 16 |
+        ((int)b[1] & 0xff) << 8 | ((int)b[0] & 0xff));
+      Assert.assertEquals(a[1],
+        ((int)b[7] & 0xff) << 24 | ((int)b[6] & 0xff) << 16 |
+        ((int)b[5] & 0xff) << 8 | ((int)b[4] & 0xff));
+    }
+    for (int i = 48; i < memory.size(); i++) {
+      Assert.assertEquals((byte) -1, memory.getByte(i));
+    }
+
+    assert(memory.subBlock(0, memory.size()) == memory);
+
+    try {
+      memory.subBlock(-8, 8);
+      Assert.fail();
+    } catch (Exception expected) {
+      Assert.assertThat(expected.getMessage(), containsString("non-negative"));
+    }
+
+    try {
+      memory.subBlock(0, -8);
+      Assert.fail();
+    } catch (Exception expected) {
+      Assert.assertThat(expected.getMessage(), containsString("non-negative"));
+    }
+
+    try {
+      memory.subBlock(0, length + 8);
+      Assert.fail();
+    } catch (Exception expected) {
+      Assert.assertThat(expected.getMessage(), containsString("should not be larger than"));
+    }
+
+    try {
+      memory.subBlock(8, length - 4);
+      Assert.fail();
+    } catch (Exception expected) {
+      Assert.assertThat(expected.getMessage(), containsString("should not be larger than"));
+    }
+
+    try {
+      memory.subBlock(length + 8, 4);
+      Assert.fail();
+    } catch (Exception expected) {
+      Assert.assertThat(expected.getMessage(), containsString("should not be larger than"));
+    }
+  }
+
+  @Test
+  public void ByteArrayMemoryBlockTest() {
+    byte[] obj = new byte[56];
+    long offset = Platform.BYTE_ARRAY_OFFSET;
+    int length = obj.length;
+
+    MemoryBlock memory = new ByteArrayMemoryBlock(obj, offset, length);
+    check(memory, obj, offset, length);
+
+    memory = ByteArrayMemoryBlock.fromArray(obj);
+    check(memory, obj, offset, length);
+
+    obj = new byte[112];
+    memory = new ByteArrayMemoryBlock(obj, offset, length);
+    check(memory, obj, offset, length);
+  }
+
+  @Test
+  public void OnHeapMemoryBlockTest() {
+    long[] obj = new long[7];
+    long offset = Platform.LONG_ARRAY_OFFSET;
+    int length = obj.length * 8;
+
+    MemoryBlock memory = new OnHeapMemoryBlock(obj, offset, length);
+    check(memory, obj, offset, length);
+
+    memory = OnHeapMemoryBlock.fromArray(obj);
+    check(memory, obj, offset, length);
+
+    obj = new long[14];
+    memory = new OnHeapMemoryBlock(obj, offset, length);
+    check(memory, obj, offset, length);
+  }
+
+  @Test
+  public void OffHeapArrayMemoryBlockTest() {
+    MemoryAllocator memoryAllocator = new UnsafeMemoryAllocator();
+    MemoryBlock memory = memoryAllocator.allocate(56);
+    Object obj = memory.getBaseObject();
+    long offset = memory.getBaseOffset();
+    int length = 56;
+
+    check(memory, obj, offset, length);
+
+    long address = Platform.allocateMemory(112);
+    memory = new OffHeapMemoryBlock(address, length);
+    obj = memory.getBaseObject();
+    offset = memory.getBaseOffset();
+    check(memory, obj, offset, length);
+  }
+}
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 9b303fa5bc6c5..b394091300d14 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -26,6 +26,9 @@
 import com.google.common.collect.ImmutableMap;

 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.OnHeapMemoryBlock;
 import org.junit.Test;

 import static org.junit.Assert.*;
@@ -515,7 +518,8 @@ public void writeToOutputStreamUnderflow() throws IOException {
     final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8);

     for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) {
-      UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i)
+      new UTF8String(
+        new ByteArrayMemoryBlock(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i))
         .writeTo(outputStream);
       final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length);
       assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString());
@@ -530,7 +534,7 @@ public void writeToOutputStreamSlice() throws IOException {

     for (int i = 0; i < test.length; ++i) {
       for (int j = 0; j < test.length - i; ++j) {
-        UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET + i, j)
+        new UTF8String(ByteArrayMemoryBlock.fromArray(test).subBlock(i, j))
           .writeTo(outputStream);

         assertArrayEquals(Arrays.copyOfRange(test, i, i + j), outputStream.toByteArray());
@@ -561,7 +565,7 @@ public void writeToOutputStreamOverflow() throws IOException {

     for (final long offset : offsets) {
       try {
-        fromAddress(test, BYTE_ARRAY_OFFSET + offset, test.length)
+        new UTF8String(ByteArrayMemoryBlock.fromArray(test).subBlock(offset, test.length))
           .writeTo(outputStream);

         throw new IllegalStateException(Long.toString(offset));
@@ -588,26 +592,25 @@ public void writeToOutputStream() throws IOException {
   }

   @Test
-  public void writeToOutputStreamIntArray() throws IOException {
+  public void writeToOutputStreamLongArray() throws IOException {
     // verify that writes work on objects that are not byte arrays
-    final ByteBuffer buffer = StandardCharsets.UTF_8.encode("大千世界");
+    final ByteBuffer buffer = StandardCharsets.UTF_8.encode("3千大千世界");
     buffer.position(0);
     buffer.order(ByteOrder.nativeOrder());

     final int length = buffer.limit();
-    assertEquals(12, length);
+    assertEquals(16, length);

-    final int ints = length / 4;
-    final int[] array = new int[ints];
+    final int longs = length / 8;
+    final long[] array = new long[longs];

-    for (int i = 0; i < ints; ++i) {
-      array[i] = buffer.getInt();
+    for (int i = 0; i < longs; ++i) {
+      array[i] = buffer.getLong();
     }

     final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
-    fromAddress(array, Platform.INT_ARRAY_OFFSET, length)
-      .writeTo(outputStream);
-    assertEquals("大千世界", outputStream.toString("UTF-8"));
+    new UTF8String(OnHeapMemoryBlock.fromArray(array)).writeTo(outputStream);
+    assertEquals("3千大千世界", outputStream.toString("UTF-8"));
   }

   @Test
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index d07faf1da1248..8651a639c07f7 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -311,7 +311,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
       // this could trigger spilling to free some pages.
       return allocatePage(size, consumer);
     }
-    page.pageNumber = pageNumber;
+    page.setPageNumber(pageNumber);
     pageTable[pageNumber] = page;
     if (logger.isTraceEnabled()) {
       logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired);
@@ -323,25 +323,25 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
    * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}.
    */
   public void freePage(MemoryBlock page, MemoryConsumer consumer) {
-    assert (page.pageNumber != MemoryBlock.NO_PAGE_NUMBER) :
+    assert (page.getPageNumber() != MemoryBlock.NO_PAGE_NUMBER) :
       "Called freePage() on memory that wasn't allocated with allocatePage()";
-    assert (page.pageNumber != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
+    assert (page.getPageNumber() != MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER) :
       "Called freePage() on a memory block that has already been freed";
-    assert (page.pageNumber != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) :
+    assert (page.getPageNumber() != MemoryBlock.FREED_IN_TMM_PAGE_NUMBER) :
       "Called freePage() on a memory block that has already been freed";
-    assert(allocatedPages.get(page.pageNumber));
-    pageTable[page.pageNumber] = null;
+    assert(allocatedPages.get(page.getPageNumber()));
+    pageTable[page.getPageNumber()] = null;
     synchronized (this) {
-      allocatedPages.clear(page.pageNumber);
+      allocatedPages.clear(page.getPageNumber());
     }
     if (logger.isTraceEnabled()) {
-      logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
+      logger.trace("Freed page number {} ({} bytes)", page.getPageNumber(), page.size());
     }
     long pageSize = page.size();
     // Clear the page number before passing the block to the MemoryAllocator's free().
     // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed
     // page has been inappropriately directly freed without calling TMM.freePage().
-    page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
+    page.setPageNumber(MemoryBlock.FREED_IN_TMM_PAGE_NUMBER);
     memoryManager.tungstenMemoryAllocator().free(page);
     releaseExecutionMemory(pageSize, consumer);
   }
@@ -363,7 +363,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
       // relative to the page's base offset; this relative offset will fit in 51 bits.
       offsetInPage -= page.getBaseOffset();
     }
-    return encodePageNumberAndOffset(page.pageNumber, offsetInPage);
+    return encodePageNumberAndOffset(page.getPageNumber(), offsetInPage);
   }

   @VisibleForTesting
@@ -434,7 +434,7 @@ public long cleanUpAllAllocatedMemory() {
     for (MemoryBlock page : pageTable) {
       if (page != null) {
         logger.debug("unreleased page: " + page + " in task " + taskAttemptId);
-        page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER;
+        page.setPageNumber(MemoryBlock.FREED_IN_TMM_PAGE_NUMBER);
         memoryManager.tungstenMemoryAllocator().free(page);
       }
     }
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index dc36809d8911f..8f49859746b89 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -20,7 +20,6 @@
 import java.util.Comparator;

 import org.apache.spark.memory.MemoryConsumer;
-import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.Sorter;
@@ -105,13 +104,7 @@ public void reset() {

   public void expandPointerArray(LongArray newArray) {
     assert(newArray.size() > array.size());
-    Platform.copyMemory(
-      array.getBaseObject(),
-      array.getBaseOffset(),
-      newArray.getBaseObject(),
-      newArray.getBaseOffset(),
-      pos * 8L
-    );
+    MemoryBlock.copyMemory(array.memoryBlock(), newArray.memoryBlock(), pos * 8L);
     consumer.freeArray(array);
     array = newArray;
     usableCapacity = getUsableCapacity();
@@ -180,10 +173,7 @@ public ShuffleSorterIterator getSortedIterator() {
         PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX,
         PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false);
     } else {
-      MemoryBlock unused = new MemoryBlock(
-        array.getBaseObject(),
-        array.getBaseOffset() + pos * 8L,
-        (array.size() - pos) * 8L);
+      MemoryBlock unused = array.memoryBlock().subBlock(pos * 8L, (array.size() - pos) * 8L);
       LongArray buffer = new LongArray(unused);
       Sorter<PackedRecordPointer, LongArray> sorter =
         new Sorter<>(new ShuffleSortDataFormat(buffer));
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
index 717bdd79d47ef..254449e95443e 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
@@ -17,8 +17,8 @@

 package org.apache.spark.shuffle.sort;

-import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.SortDataFormat;

 final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, LongArray> {
@@ -60,13 +60,8 @@ public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) {

   @Override
   public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) {
-    Platform.copyMemory(
-      src.getBaseObject(),
-      src.getBaseOffset() + srcPos * 8L,
-      dst.getBaseObject(),
-      dst.getBaseOffset() + dstPos * 8L,
-      length * 8L
-    );
+    MemoryBlock.copyMemory(src.memoryBlock(), srcPos * 8L,
+      dst.memoryBlock(), dstPos * 8L, length * 8L);
   }

   @Override
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 66118f454159b..4fc19b1721518 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -544,7 +544,7 @@ public long spill() throws IOException {
         // is accessing the current record. We free this page in that caller's next loadNext()
         // call.
         for (MemoryBlock page : allocatedPages) {
-          if (!loaded || page.pageNumber !=
+          if (!loaded || page.getPageNumber() !=
             ((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) {
             released += page.size();
             freePage(page);
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index b3c27d83da172..20a7a8b267438 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -26,7 +26,6 @@
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.SparkOutOfMemoryError;
 import org.apache.spark.memory.TaskMemoryManager;
-import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.UnsafeAlignedOffset;
 import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.unsafe.memory.MemoryBlock;
@@ -216,12 +215,7 @@ public void expandPointerArray(LongArray newArray) {
     if (newArray.size() < array.size()) {
       throw new SparkOutOfMemoryError("Not enough memory to grow pointer array");
     }
-    Platform.copyMemory(
-      array.getBaseObject(),
-      array.getBaseOffset(),
-      newArray.getBaseObject(),
-      newArray.getBaseOffset(),
-      pos * 8L);
+    MemoryBlock.copyMemory(array.memoryBlock(), newArray.memoryBlock(), pos * 8L);
     consumer.freeArray(array);
     array = newArray;
     usableCapacity = getUsableCapacity();
@@ -348,10 +342,7 @@ public UnsafeSorterIterator getSortedIterator() {
         array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7,
         radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
     } else {
-      MemoryBlock unused = new MemoryBlock(
-        array.getBaseObject(),
-        array.getBaseOffset() + pos * 8L,
-        (array.size() - pos) * 8L);
+      MemoryBlock unused = array.memoryBlock().subBlock(pos * 8L, (array.size() - pos) * 8L);
       LongArray buffer = new LongArray(unused);
       Sorter<RecordPointerAndKeyPrefix, LongArray> sorter =
         new Sorter<>(new UnsafeSortDataFormat(buffer));
diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
index a0664b30d6cc2..d7d2d0b012bd3 100644
--- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -76,7 +76,7 @@ public void freeingPageSetsPageNumberToSpecialConstant() {
     final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP);
     final MemoryBlock dataPage = manager.allocatePage(256, c);
     c.freePage(dataPage);
-    Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.pageNumber);
+    Assert.assertEquals(MemoryBlock.FREED_IN_ALLOCATOR_PAGE_NUMBER, dataPage.getPageNumber());
   }

   @Test(expected = AssertionError.class)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 47173b89e91e2..3e56db5ea116a 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark._
 import org.apache.spark.memory.MemoryTestingUtils
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
 import org.apache.spark.unsafe.array.LongArray
-import org.apache.spark.unsafe.memory.MemoryBlock
+import org.apache.spark.unsafe.memory.OnHeapMemoryBlock
 import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordPointerAndKeyPrefix, UnsafeSortDataFormat}

 class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
@@ -105,9 +105,8 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     // the form [150000000, 150000001, 150000002, ...., 300000000, 0, 1, 2, ..., 149999999]
     // that can trigger copyRange() in TimSort.mergeLo() or TimSort.mergeHi()
     val ref = Array.tabulate[Long](size) { i => if (i < size / 2) size / 2 + i else i }
-    val buf = new LongArray(MemoryBlock.fromLongArray(ref))
-    val tmp = new Array[Long](size/2)
-    val tmpBuf = new LongArray(MemoryBlock.fromLongArray(tmp))
+    val buf = new LongArray(OnHeapMemoryBlock.fromArray(ref))
+    val tmpBuf = new LongArray(new OnHeapMemoryBlock((size/2) * 8L))

     new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort(
       buf, 0, size, new Comparator[RecordPointerAndKeyPrefix] {
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
index d5956ea32096a..ddf3740e76a7a 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
@@ -27,7 +27,7 @@ import com.google.common.primitives.Ints
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.internal.Logging
 import org.apache.spark.unsafe.array.LongArray
-import org.apache.spark.unsafe.memory.MemoryBlock
+import org.apache.spark.unsafe.memory.OnHeapMemoryBlock
 import org.apache.spark.util.collection.Sorter
 import org.apache.spark.util.random.XORShiftRandom
@@ -78,14 +78,14 @@ class RadixSortSuite extends SparkFunSuite with Logging {

   private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = {
     val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand }
     val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0)
-    (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended)))
+    (ref.map(i => new JLong(i)), new LongArray(OnHeapMemoryBlock.fromArray(extended)))
   }

   private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = {
     val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand }
     val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0)
-    (new LongArray(MemoryBlock.fromLongArray(ref)),
-      new LongArray(MemoryBlock.fromLongArray(extended)))
+    (new LongArray(OnHeapMemoryBlock.fromArray(ref)),
+      new LongArray(OnHeapMemoryBlock.fromArray(extended)))
   }

   private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = {
@@ -110,7 +110,7 @@ class RadixSortSuite extends SparkFunSuite with Logging {
   }

   private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) {
-    val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
+    val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L))
     new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
       buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] {
         override def compare(
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala
index c78f61ac3ef71..d67e4819b161a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala
@@ -29,7 +29,7 @@ import org.apache.spark.mllib.feature.{HashingTF => OldHashingTF}
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2}
+import org.apache.spark.unsafe.hash.Murmur3_x86_32.{hashInt, hashLong, hashUnsafeBytes2Block}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.OpenHashMap
@@ -243,8 +243,7 @@ object FeatureHasher extends DefaultParamsReadable[FeatureHasher] {
       case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
       case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
       case s: String =>
-        val utf8 = UTF8String.fromString(s)
-        hashUnsafeBytes2(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed)
+        hashUnsafeBytes2Block(UTF8String.fromString(s).getMemoryBlock, seed)
       case _ => throw new SparkException("FeatureHasher with murmur3 algorithm does not " +
         s"support type ${term.getClass.getCanonicalName} of input data.")
     }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
index 8935c8496cdbb..7b73b286fb91c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
@@ -160,7 +160,7 @@ object HashingTF {
       case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
       case s: String =>
         val utf8 = UTF8String.fromString(s)
-        hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed)
+        hashUnsafeBytesBlock(utf8.getMemoryBlock(), seed)
       case _ => throw new SparkException("HashingTF with murmur3 algorithm does not " +
         s"support type ${term.getClass.getCanonicalName} of input data.")
     }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index d18542b188f71..8546c28335536 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -27,6 +27,7 @@
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.bitset.BitSetMethods;
 import org.apache.spark.unsafe.hash.Murmur3_x86_32;
+import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.types.CalendarInterval;
 import org.apache.spark.unsafe.types.UTF8String;
@@ -230,7 +231,8 @@ public UTF8String getUTF8String(int ordinal) {
     final long offsetAndSize = getLong(ordinal);
     final int offset = (int) (offsetAndSize >> 32);
     final int size = (int) offsetAndSize;
-    return UTF8String.fromAddress(baseObject, baseOffset + offset, size);
+    MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size);
+    return new UTF8String(mb);
   }

   @Override
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 71c086029cc5b..29a1411241cf6 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -37,6 +37,7 @@
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.bitset.BitSetMethods;
 import org.apache.spark.unsafe.hash.Murmur3_x86_32;
+import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.types.CalendarInterval;
 import org.apache.spark.unsafe.types.UTF8String;
@@ -414,7 +415,8 @@ public UTF8String getUTF8String(int ordinal) {
     final long offsetAndSize = getLong(ordinal);
     final int offset = (int) (offsetAndSize >> 32);
     final int size = (int) offsetAndSize;
-    return UTF8String.fromAddress(baseObject, baseOffset + offset, size);
+    MemoryBlock mb = MemoryBlock.allocateFromObject(baseObject, baseOffset + offset, size);
+    return new UTF8String(mb);
   }

   @Override
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java
index f37ef83ad92b4..883748932ad33 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/XXH64.java
@@ -16,7 +16,10 @@
  */
 package org.apache.spark.sql.catalyst.expressions;

+import com.google.common.primitives.Ints;
+
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.memory.MemoryBlock;

 // scalastyle: off
@@ -71,13 +74,13 @@ public static long hashLong(long input, long seed) {
     return fmix(hash);
   }

-  public long hashUnsafeWords(Object base, long offset, int length) {
-    return hashUnsafeWords(base, offset, length, seed);
+  public long hashUnsafeWordsBlock(MemoryBlock mb) {
+    return hashUnsafeWordsBlock(mb, seed);
   }

-  public static long hashUnsafeWords(Object base, long offset, int length, long seed) {
-    assert (length % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)";
-    long hash = hashBytesByWords(base, offset, length, seed);
+  public static long hashUnsafeWordsBlock(MemoryBlock mb, long seed) {
+    assert (mb.size() % 8 == 0) : "lengthInBytes must be a multiple of 8 (word-aligned)";
+    long hash = hashBytesByWordsBlock(mb, seed);
     return fmix(hash);
   }

@@ -85,26 +88,32 @@ public long hashUnsafeBytes(Object base, long offset, int length) {
     return hashUnsafeBytes(base, offset, length, seed);
   }

-  public static long hashUnsafeBytes(Object base, long offset, int length, long seed) {
+  public static long hashUnsafeBytesBlock(MemoryBlock mb, long seed) {
+    long offset = 0;
+    long length = mb.size();
     assert (length >= 0) : "lengthInBytes cannot be negative";
-    long hash = hashBytesByWords(base, offset, length, seed);
+    long hash = hashBytesByWordsBlock(mb, seed);
     long end = offset + length;
     offset += length & -8;
     if (offset + 4L <= end) {
-      hash ^= (Platform.getInt(base, offset) & 0xFFFFFFFFL) * PRIME64_1;
+      hash ^= (mb.getInt(offset) & 0xFFFFFFFFL) * PRIME64_1;
       hash = Long.rotateLeft(hash, 23) * PRIME64_2 + PRIME64_3;
       offset += 4L;
     }

     while (offset < end) {
-      hash ^= (Platform.getByte(base, offset) & 0xFFL) * PRIME64_5;
+      hash ^= (mb.getByte(offset) & 0xFFL) * PRIME64_5;
       hash = Long.rotateLeft(hash, 11) * PRIME64_1;
       offset++;
     }
     return fmix(hash);
   }

+  public static long hashUnsafeBytes(Object base, long offset, int length, long seed) {
+    return hashUnsafeBytesBlock(MemoryBlock.allocateFromObject(base, offset, length), seed);
+  }
+
   private static long fmix(long hash) {
     hash ^= hash >>> 33;
     hash *= PRIME64_2;
@@ -114,30 +123,31 @@ private static long fmix(long hash) {
     return hash;
   }

-  private static long hashBytesByWords(Object base, long offset, int length, long seed) {
-    long end = offset + length;
+  private static long hashBytesByWordsBlock(MemoryBlock mb, long seed) {
+    long offset = 0;
+    long length = mb.size();
     long hash;
     if (length >= 32) {
-      long limit = end - 32;
+      long limit = length - 32;
       long v1 = seed + PRIME64_1 + PRIME64_2;
       long v2 = seed + PRIME64_2;
       long v3 = seed;
       long v4 = seed - PRIME64_1;

       do {
-        v1 += Platform.getLong(base, offset) * PRIME64_2;
+        v1 += mb.getLong(offset) * PRIME64_2;
         v1 = Long.rotateLeft(v1, 31);
         v1 *= PRIME64_1;

-        v2 += Platform.getLong(base, offset + 8) * PRIME64_2;
+        v2 += mb.getLong(offset + 8) * PRIME64_2;
         v2 = Long.rotateLeft(v2, 31);
         v2 *= PRIME64_1;

-        v3 += Platform.getLong(base, offset + 16) * PRIME64_2;
+        v3 += mb.getLong(offset + 16) * PRIME64_2;
         v3 = Long.rotateLeft(v3, 31);
         v3 *= PRIME64_1;

-        v4 += Platform.getLong(base, offset + 24) * PRIME64_2;
+        v4 += mb.getLong(offset + 24) * PRIME64_2;
         v4 = Long.rotateLeft(v4, 31);
         v4 *= PRIME64_1;
@@ -178,9 +188,9 @@ private static long hashBytesByWords(Object base, long offset, int length, long
     hash += length;

-    long limit = end - 8;
+    long limit = length - 8;
     while (offset <= limit) {
-      long k1 = Platform.getLong(base, offset);
+      long k1 = mb.getLong(offset);
       hash ^= Long.rotateLeft(k1 * PRIME64_2, 31) * PRIME64_1;
       hash = Long.rotateLeft(hash, 27) * PRIME64_1 + PRIME64_4;
       offset += 8L;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 055ebf6c0da54..dbe3dff7ff6ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.hash.Murmur3_x86_32
+import org.apache.spark.unsafe.memory.MemoryBlock
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -359,10 +360,8 @@ abstract class HashExpression[E] extends Expression {
   }

   protected def genHashString(input: String, result: String): String = {
-    val baseObject = s"$input.getBaseObject()"
-    val baseOffset = s"$input.getBaseOffset()"
-    val numBytes = s"$input.numBytes()"
-    s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);"
+    val mb = s"$input.getMemoryBlock()"
+    s"$result = $hasherClassName.hashUnsafeBytesBlock($mb, $result);"
   }

   protected def genHashForMap(
@@ -464,6 +463,8 @@ abstract class InterpretedHashFunction {
   protected def hashUnsafeBytes(base: AnyRef, offset: Long, length: Int, seed: Long): Long

+  protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long
+
   /**
    * Computes hash of a given `value` of type `dataType`. The caller needs to check the validity
    * of input `value`.
    */
@@ -489,8 +490,7 @@ abstract class InterpretedHashFunction {
       case c: CalendarInterval => hashInt(c.months, hashLong(c.microseconds, seed))
       case a: Array[Byte] =>
         hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed)
-      case s: UTF8String =>
-        hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed)
+      case s: UTF8String => hashUnsafeBytesBlock(s.getMemoryBlock(), seed)
       case array: ArrayData =>
         val elementType = dataType match {
@@ -577,9 +577,15 @@ object Murmur3HashFunction extends InterpretedHashFunction {
     Murmur3_x86_32.hashLong(l, seed.toInt)
   }

-  override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
+  override protected def hashUnsafeBytes(
+      base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
     Murmur3_x86_32.hashUnsafeBytes(base, offset, len, seed.toInt)
   }
+
+  override protected def hashUnsafeBytesBlock(
+      base: MemoryBlock, seed: Long): Long = {
+    Murmur3_x86_32.hashUnsafeBytesBlock(base, seed.toInt)
+  }
 }

 /**
@@ -604,9 +610,14 @@ object XxHash64Function extends InterpretedHashFunction {
   override protected def hashLong(l: Long, seed: Long): Long = XXH64.hashLong(l, seed)

-  override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
+  override protected def hashUnsafeBytes(
+      base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
     XXH64.hashUnsafeBytes(base, offset, len, seed)
   }
+
+  override protected def hashUnsafeBytesBlock(base: MemoryBlock, seed: Long): Long = {
+    XXH64.hashUnsafeBytesBlock(base, seed)
+  }
 }

 /**
@@ -713,10 +724,8 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
     """

   override protected def genHashString(input: String, result: String): String = {
-    val baseObject = s"$input.getBaseObject()"
-    val baseOffset = s"$input.getBaseOffset()"
-    val numBytes = s"$input.numBytes()"
-    s"$result = $hasherClassName.hashUnsafeBytes($baseObject, $baseOffset, $numBytes);"
+    val mb = s"$input.getMemoryBlock()"
+    s"$result = $hasherClassName.hashUnsafeBytesBlock($mb);"
   }

   override protected def genHashForArray(
@@ -804,10 +813,14 @@ object HiveHashFunction extends InterpretedHashFunction {
     HiveHasher.hashLong(l)
   }

-  override protected def hashUnsafeBytes(base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
+  override protected def hashUnsafeBytes(
+      base: AnyRef, offset: Long, len: Int, seed: Long): Long = {
     HiveHasher.hashUnsafeBytes(base, offset, len)
   }

+  override protected def hashUnsafeBytesBlock(
+      base: MemoryBlock, seed: Long): Long = HiveHasher.hashUnsafeBytesBlock(base)
+
   private val HIVE_DECIMAL_MAX_PRECISION = 38
   private val HIVE_DECIMAL_MAX_SCALE = 38
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java
index b67c6f3e6e85e..8ffc1d7c24d61 100644
--- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java
+++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/HiveHasherSuite.java
@@ -18,6 +18,8 @@
 package org.apache.spark.sql.catalyst.expressions;

 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock;
+import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.types.UTF8String;
 import org.junit.Assert;
 import org.junit.Test;
@@ -53,7 +55,7 @@ public void testKnownStringAndIntInputs() {

     for (int i = 0; i < inputs.length; i++) {
       UTF8String s = UTF8String.fromString("val_" + inputs[i]);
-      int hash = HiveHasher.hashUnsafeBytes(s.getBaseObject(), s.getBaseOffset(), s.numBytes());
+      int hash = HiveHasher.hashUnsafeBytesBlock(s.getMemoryBlock());
       Assert.assertEquals(expected[i], ((31 * inputs[i]) + hash));
     }
   }
@@ -89,13 +91,13 @@ public void randomizedStressTestBytes() {
       int byteArrSize = rand.nextInt(100) * 8;
       byte[] bytes = new byte[byteArrSize];
       rand.nextBytes(bytes);
+      MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes);

       Assert.assertEquals(
-        HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize),
-        HiveHasher.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize));
+        HiveHasher.hashUnsafeBytesBlock(mb),
+        HiveHasher.hashUnsafeBytesBlock(mb));

-      hashcodes.add(HiveHasher.hashUnsafeBytes(
-        bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize));
+      hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb));
     }

     // A very loose bound.
@@ -112,13 +114,13 @@ public void randomizedStressTestPaddedStrings() {
       byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8);
       byte[] paddedBytes = new byte[byteArrSize];
       System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length);
+      MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes);

       Assert.assertEquals(
-        HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize),
-        HiveHasher.hashUnsafeBytes(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize));
+        HiveHasher.hashUnsafeBytesBlock(mb),
+        HiveHasher.hashUnsafeBytesBlock(mb));

-      hashcodes.add(HiveHasher.hashUnsafeBytes(
-        paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize));
+      hashcodes.add(HiveHasher.hashUnsafeBytesBlock(mb));
     }

     // A very loose bound.
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java
index 711887f02832a..bb7cad1212a26 100644
--- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java
+++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/XXH64Suite.java
@@ -24,6 +24,8 @@
 import java.util.Set;

 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock;
+import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.junit.Assert;
 import org.junit.Test;
@@ -128,13 +130,13 @@ public void randomizedStressTestBytes() {
       int byteArrSize = rand.nextInt(100) * 8;
       byte[] bytes = new byte[byteArrSize];
       rand.nextBytes(bytes);
+      MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes);

       Assert.assertEquals(
-        hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize),
-        hasher.hashUnsafeWords(bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize));
+        hasher.hashUnsafeWordsBlock(mb),
+        hasher.hashUnsafeWordsBlock(mb));

-      hashcodes.add(hasher.hashUnsafeWords(
-        bytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize));
+      hashcodes.add(hasher.hashUnsafeWordsBlock(mb));
     }

     // A very loose bound.
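For reference, the invariant these stress tests lean on can be shown in isolation: after this patch, the legacy raw (base, offset, length) overloads simply wrap their arguments in a MemoryBlock and delegate to the Block variants, so both entry points must agree. The following is a minimal, illustrative sketch, not part of the patch; the class name and main-method harness are invented, while the XXH64, ByteArrayMemoryBlock, and Platform calls are exactly the ones introduced or kept by the diff above.

// Illustrative sketch only (hypothetical class, not part of this patch).
import java.nio.charset.StandardCharsets;

import org.apache.spark.sql.catalyst.expressions.XXH64;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.ByteArrayMemoryBlock;
import org.apache.spark.unsafe.memory.MemoryBlock;

public class XXH64BlockSketch {
  public static void main(String[] args) {
    byte[] bytes = "hello xxh64".getBytes(StandardCharsets.UTF_8);
    long seed = 42L;

    // New path: wrap the whole array in a MemoryBlock and hash the block.
    MemoryBlock mb = ByteArrayMemoryBlock.fromArray(bytes);
    long viaBlock = XXH64.hashUnsafeBytesBlock(mb, seed);

    // Legacy path: raw base object + offset + length. Per this patch it
    // builds a MemoryBlock via MemoryBlock.allocateFromObject and delegates.
    long viaOffsets =
      XXH64.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed);

    // Both paths must produce the same hash (run with -ea to enable asserts).
    assert viaBlock == viaOffsets;
  }
}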
@@ -151,13 +153,13 @@ public void randomizedStressTestPaddedStrings() {
       byte[] strBytes = String.valueOf(i).getBytes(StandardCharsets.UTF_8);
       byte[] paddedBytes = new byte[byteArrSize];
       System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length);
+      MemoryBlock mb = ByteArrayMemoryBlock.fromArray(paddedBytes);

       Assert.assertEquals(
-        hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize),
-        hasher.hashUnsafeWords(paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize));
+        hasher.hashUnsafeWordsBlock(mb),
+        hasher.hashUnsafeWordsBlock(mb));

-      hashcodes.add(hasher.hashUnsafeWords(
-        paddedBytes, Platform.BYTE_ARRAY_OFFSET, byteArrSize));
+      hashcodes.add(hasher.hashUnsafeWordsBlock(mb));
     }

     // A very loose bound.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index 754c26579ff08..4733f36174f42 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -23,6 +23,7 @@
 import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.memory.OffHeapMemoryBlock;
 import org.apache.spark.unsafe.types.UTF8String;

 /**
@@ -206,7 +207,7 @@ public byte[] getBytes(int rowId, int count) {

   @Override
   protected UTF8String getBytesAsUTF8String(int rowId, int count) {
-    return UTF8String.fromAddress(null, data + rowId, count);
+    return new UTF8String(new OffHeapMemoryBlock(data + rowId, count));
   }

   //
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index f8e37e995a17f..227a16f7e69e9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -25,6 +25,7 @@
 import org.apache.spark.annotation.InterfaceStability;
 import org.apache.spark.sql.execution.arrow.ArrowUtils;
 import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.memory.OffHeapMemoryBlock;
 import org.apache.spark.unsafe.types.UTF8String;

 /**
@@ -377,9 +378,10 @@ final UTF8String getUTF8String(int rowId) {
       if (stringResult.isSet == 0) {
         return null;
       } else {
-        return UTF8String.fromAddress(null,
+        return new UTF8String(new OffHeapMemoryBlock(
           stringResult.buffer.memoryAddress() + stringResult.start,
-          stringResult.end - stringResult.start);
+          stringResult.end - stringResult.start
+        ));
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
index 50ae26a3ff9d9..470b93efd1974 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.benchmark
 import java.util.{Arrays, Comparator}

 import org.apache.spark.unsafe.array.LongArray
-import org.apache.spark.unsafe.memory.MemoryBlock
+import org.apache.spark.unsafe.memory.OnHeapMemoryBlock
 import org.apache.spark.util.Benchmark
 import org.apache.spark.util.collection.Sorter
 import org.apache.spark.util.collection.unsafe.sort._
@@ -36,7 +36,7 @@ import org.apache.spark.util.random.XORShiftRandom
 class SortBenchmark extends BenchmarkBase {

   private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) {
-    val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
+    val sortBuffer = new LongArray(new OnHeapMemoryBlock(buf.size() * 8L))
     new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
       buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
         override def compare(
@@ -50,8 +50,8 @@ class SortBenchmark extends BenchmarkBase {
   private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = {
     val ref = Array.tabulate[Long](size * 2) { i => rand }
     val extended = ref ++ Array.fill[Long](size * 2)(0)
-    (new LongArray(MemoryBlock.fromLongArray(ref)),
-      new LongArray(MemoryBlock.fromLongArray(extended)))
+    (new LongArray(OnHeapMemoryBlock.fromArray(ref)),
+      new LongArray(OnHeapMemoryBlock.fromArray(extended)))
   }

   ignore("sort") {
@@ -60,7 +60,7 @@ class SortBenchmark extends BenchmarkBase {
       val benchmark = new Benchmark("radix sort " + size, size)
       benchmark.addTimerCase("reference TimSort key prefix array") { timer =>
         val array = Array.tabulate[Long](size * 2) { i => rand.nextLong }
-        val buf = new LongArray(MemoryBlock.fromLongArray(array))
+        val buf = new LongArray(OnHeapMemoryBlock.fromArray(array))
         timer.startTiming()
         referenceKeyPrefixSort(buf, 0, size, PrefixComparators.BINARY)
         timer.stopTiming()
@@ -78,7 +78,7 @@ class SortBenchmark extends BenchmarkBase {
           array(i) = rand.nextLong & 0xff
           i += 1
         }
-        val buf = new LongArray(MemoryBlock.fromLongArray(array))
+        val buf = new LongArray(OnHeapMemoryBlock.fromArray(array))
         timer.startTiming()
         RadixSort.sort(buf, size, 0, 7, false, false)
         timer.stopTiming()
@@ -90,7 +90,7 @@ class SortBenchmark extends BenchmarkBase {
           array(i) = rand.nextLong & 0xffff
           i += 1
         }
-        val buf = new LongArray(MemoryBlock.fromLongArray(array))
+        val buf = new LongArray(OnHeapMemoryBlock.fromArray(array))
         timer.startTiming()
         RadixSort.sort(buf, size, 0, 7, false, false)
         timer.stopTiming()
@@ -102,7 +102,7 @@ class SortBenchmark extends BenchmarkBase {
          array(i) = rand.nextLong
           i += 1
         }
-        val buf = new LongArray(MemoryBlock.fromLongArray(array))
+        val buf = new LongArray(OnHeapMemoryBlock.fromArray(array))
         timer.startTiming()
         RadixSort.sort(buf, size, 0, 7, false, false)
         timer.stopTiming()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
index ffda33cf906c5..25ee95daa034c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
@@ -22,13 +22,13 @@ import java.io.File

 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.unsafe.memory.MemoryBlock
+import org.apache.spark.unsafe.memory.OnHeapMemoryBlock
 import org.apache.spark.util.Utils

 class RowQueueSuite extends SparkFunSuite {

   test("in-memory queue") {
-    val page = MemoryBlock.fromLongArray(new Array[Long](1<<10))
+    val page = new OnHeapMemoryBlock((1<<10) * 8L)
     val queue = new InMemoryRowQueue(page, 1) {
       override def close() {}
     }
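The test-side churn above is mechanical: wherever a suite previously built a LongArray over a freshly allocated long[] via MemoryBlock.fromLongArray, it now either wraps an existing array with OnHeapMemoryBlock.fromArray or allocates the block directly by byte size, hence the recurring "* 8L". A minimal, illustrative sketch of that allocation pattern follows; the class name and main-method harness are invented, the OnHeapMemoryBlock constructors are those used in the hunks above, and the LongArray accessors are its existing API.

// Illustrative sketch only (hypothetical class, not part of this patch).
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.OnHeapMemoryBlock;

public class OnHeapBlockSketch {
  public static void main(String[] args) {
    // Block sizes are given in bytes, so an array of 1024 longs is 1024 * 8L.
    LongArray array = new LongArray(new OnHeapMemoryBlock(1024 * 8L));
    array.zeroOut();
    array.set(0, 42L);
    assert array.get(0) == 42L;

    // Wrapping an existing long[] still works, now via fromArray.
    LongArray wrapped = new LongArray(OnHeapMemoryBlock.fromArray(new long[]{1L, 2L, 3L}));
    assert wrapped.size() == 3;
  }
}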