[SPARK-24257][SQL] LongToUnsafeRowMap calculate the new size may be wrong #21311

Closed · wants to merge 6 commits
HashedRelation.scala
@@ -557,7 +557,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
   def append(key: Long, row: UnsafeRow): Unit = {
     val sizeInBytes = row.getSizeInBytes
     if (sizeInBytes >= (1 << SIZE_BITS)) {
-      sys.error("Does not support row that is larger than 256M")
+      throw new UnsupportedOperationException("Does not support row that is larger than 256M")
     }
 
     if (key < minKey) {
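
For context on the 256M check above: the map packs each value's page offset and length into a single long, keeping the length in the low SIZE_BITS bits (28 in this class), so a row of 1 << 28 bytes or more could not be encoded. A minimal standalone sketch of that packing (the constants mirror the Spark source as I understand it; the helper names are made up):

    object AddressPacking {
      val SIZE_BITS = 28
      val SIZE_MASK = 0xfffffffL  // low 28 bits

      def pack(offset: Long, size: Int): Long = (offset << SIZE_BITS) | size
      def offsetOf(address: Long): Long = address >>> SIZE_BITS
      def sizeOf(address: Long): Int = (address & SIZE_MASK).toInt

      def main(args: Array[String]): Unit = {
        val addr = pack(offset = 1024L, size = 100)
        // prints: offset=1024 size=100
        println(s"offset=${offsetOf(addr)} size=${sizeOf(addr)}")
        // a size >= 1 << 28 would overflow into the offset bits,
        // hence the UnsupportedOperationException in append
      }
    }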
@@ -567,19 +567,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       maxKey = key
     }
 
-    // There is 8 bytes for the pointer to next value
-    if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) {
-      val used = page.length
-      if (used >= (1 << 30)) {
-        sys.error("Can not build a HashedRelation that is larger than 8G")
-      }
-      ensureAcquireMemory(used * 8L * 2)
-      val newPage = new Array[Long](used * 2)
-      Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
-        cursor - Platform.LONG_ARRAY_OFFSET)
-      page = newPage
-      freeMemory(used * 8L)
-    }
+    grow(row.getSizeInBytes)
 
     // copy the bytes of UnsafeRow
     val offset = cursor
@@ -615,7 +603,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
         growArray()
       } else if (numKeys > array.length / 2 * 0.75) {
         // The fill ratio should be less than 0.75
-        sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys")
+        throw new UnsupportedOperationException(
+          "Cannot build HashedRelation with more than 1/3 billions unique keys")
       }
     }
   } else {
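
An aside on where "1/3 billion" comes from: the index array stores two longs per key, its length tops out around 1 << 30, and the fill ratio must stay below 0.75, which caps the key count at roughly 0.4 billion. A quick standalone check (the 1 << 30 cap and two-slots-per-key layout are assumptions read off the surrounding code):

    object KeyLimit {
      def main(args: Array[String]): Unit = {
        val maxArrayLength = 1L << 30  // assumed cap on the index array
        val slotsPerKey = 2            // key and packed address side by side
        val fillRatio = 0.75
        val maxKeys = (maxArrayLength / slotsPerKey * fillRatio).toLong
        println(maxKeys)  // 402653184, a bit over 1/3 billion
      }
    }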
@@ -626,6 +615,25 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     }
   }
 
+  private def grow(inputRowSize: Int): Unit = {
+    // There is 8 bytes for the pointer to next value
+    val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8
+    if (neededNumWords > page.length) {
+      if (neededNumWords > (1 << 30)) {
+        throw new UnsupportedOperationException(
+          "Can not build a HashedRelation that is larger than 8G")
+      }
+      val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30))
+      ensureAcquireMemory(newNumWords * 8L)
+      val newPage = new Array[Long](newNumWords.toInt)
+      Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
+        cursor - Platform.LONG_ARRAY_OFFSET)
+      val used = page.length
+      page = newPage
+      freeMemory(used * 8L)
+    }
+  }
+
   private def growArray(): Unit = {
     var old_array = array
     val n = array.length
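
To restate the heart of the fix (a standalone sketch using the test's sizes, not the Spark class itself): the old code doubled the page exactly once without re-checking that the doubled page fits the incoming row, so a row larger than twice the current page was later copied past the end of the array. The new grow computes the number of 8-byte words actually needed, counting the 8-byte next-value pointer and rounding up, and grows to at least that:

    object GrowSketch {
      val LONG_ARRAY_OFFSET = 16L  // stand-in for Platform.LONG_ARRAY_OFFSET

      def main(args: Array[String]): Unit = {
        val pageWords = 1L << 17        // initial page: 2^17 longs = 1M bytes
        val cursor = LONG_ARRAY_OFFSET  // nothing written yet
        val rowSize = 1 << 22           // the test's 4M-byte row

        // old policy: double once, never re-check
        val oldWords = pageWords * 2
        // new policy: bytes needed = used bytes + 8-byte pointer + row size,
        // rounded up to whole longs; grow to at least that (capped at 1 << 30)
        val neededWords = (cursor - LONG_ARRAY_OFFSET + 8 + rowSize + 7) / 8
        val newWords = math.max(neededWords, math.min(pageWords * 2, 1L << 30))

        // prints: needed=524289 old=262144 new=524289
        // the old doubled page (2M bytes) cannot hold the 4M-byte row,
        // so the subsequent copyMemory wrote out of bounds and lost data
        println(s"needed=$neededWords old=$oldWords new=$newWords")
      }
    }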
HashedRelationSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.collection.CompactBuffer
@@ -254,6 +254,30 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     map.free()
   }
 
+  test("LongToUnsafeRowMap with big values") {
+    val taskMemoryManager = new TaskMemoryManager(
+      new StaticMemoryManager(
+        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
+        Long.MaxValue,
+        Long.MaxValue,
+        1),
+      0)
+    val unsafeProj = UnsafeProjection.create(Array[DataType](StringType))
+    val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+
+    val key = 0L
+    // the page array is initialized with length 1 << 17 longs (1M bytes),
+    // so here we need a value larger than 1 << 18 longs (2M bytes) to trigger the bug
+    val bigStr = UTF8String.fromString("x" * (1 << 22))
Review thread on the line above:

Contributor:
to double check, do we have to use 1 << 22 to trigger this bug?

Contributor Author:
Not necessarily. I just chose a larger value to make the data loss easier to observe.

Contributor:
Do you mean this bug can't be reproduced consistently? E.g. if we pick 1 << 18 + 1, we may not expose this bug, so we have to use 1 << 22 to reproduce it 100% of the time?

Contributor Author:
In LongToUnsafeRowMap#getRow:
resultRow = UnsafeRow#pointTo(page (1 << 18), baseOffset (16), sizeInBytes (1 << 21 + 16))

In UTF8String#getBytes:
copyMemory(base (page), offset, bytes, BYTE_ARRAY_OFFSET, numBytes (1 << 21 + 16));

When the requested size is only slightly larger than the page, the copy reads past the end of the array but can sometimes still return the original value.

Since SPARK-10399, UnsafeRow#getUTF8String checks the size, so picking 1 << 18 + 1 reproduces this bug 100% of the time. But before that patch, a difference that is too small sometimes does not trigger it, so I chose a larger value.

My understanding may be wrong. Please advise. Thank you.

import java.lang.reflect.Field;
import sun.misc.Unsafe;

public class CopyMemoryDemo {
    public static void main(String[] args) throws Exception {
        // Grab the Unsafe singleton via reflection.
        Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe");
        unsafeField.setAccessible(true);
        Unsafe unsafe = (Unsafe) unsafeField.get(null);

        String value = "xxxxx";
        byte[] src = value.getBytes();

        byte[] dst = new byte[3];    // too small for the 5 bytes we copy
        byte[] newDst = new byte[5];

        // 16 is the byte-array base offset on a typical 64-bit HotSpot JVM
        // (Unsafe.ARRAY_BYTE_BASE_OFFSET is the portable way to get it).
        // Copying 5 bytes into the 3-byte dst writes past the array, and
        // reading 5 bytes back out of dst may still see those extra bytes,
        // which is why an undersized page can appear to round-trip correctly.
        unsafe.copyMemory(src, 16, dst, 16, src.length);
        unsafe.copyMemory(dst, 16, newDst, 16, src.length);

        System.out.println("dst:" + new String(dst));
        System.out.println("newDst:" + new String(newDst));
    }
}

output:

dst:xxx
newDst:xxxxx

Contributor:
Then 1 << 19 should be good enough, as it doubles the size?

Contributor Author:
Yes, I think so.


+    map.append(key, unsafeProj(InternalRow(bigStr)))
+    map.optimize()
+
+    val resultRow = new UnsafeRow(1)
+    assert(map.getValue(key, resultRow).getUTF8String(0) === bigStr)
+    map.free()
+  }

test("Spark-14521") {
val ser = new KryoSerializer(
(new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()
Expand Down