From 276c2930a67c84e83a63e57b7c532caf7378c154 Mon Sep 17 00:00:00 2001 From: sychen Date: Thu, 24 May 2018 11:02:09 +0800 Subject: [PATCH] [SPARK-24257][SQL] LongToUnsafeRowMap calculate the new size may be wrong LongToUnsafeRowMap has a mistake when growing its page array: it blindly grows to `oldSize * 2`, while the new record may be larger than `oldSize * 2`. Then we may have a malformed UnsafeRow when querying this map, whose actual data is smaller than its declared size, and the data is corrupted. Author: sychen Closes #21311 from cxzl25/fix_LongToUnsafeRowMap_page_size. --- .../sql/execution/joins/HashedRelation.scala | 38 +++++++++++-------- .../execution/joins/HashedRelationSuite.scala | 26 ++++++++++++- 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 1465346eb802d..20ce01f4ce8cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -557,7 +557,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap def append(key: Long, row: UnsafeRow): Unit = { val sizeInBytes = row.getSizeInBytes if (sizeInBytes >= (1 << SIZE_BITS)) { - sys.error("Does not support row that is larger than 256M") + throw new UnsupportedOperationException("Does not support row that is larger than 256M") } if (key < minKey) { @@ -567,19 +567,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap maxKey = key } - // There is 8 bytes for the pointer to next value - if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) { - val used = page.length - if (used >= (1 << 30)) { - sys.error("Can not build a HashedRelation that is larger than 8G") - } - ensureAcquireMemory(used * 8L * 2) - val newPage = new Array[Long](used * 2) - Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET, - cursor - Platform.LONG_ARRAY_OFFSET) - page = newPage - freeMemory(used * 8L) - } + grow(row.getSizeInBytes) // copy the bytes of UnsafeRow val offset = cursor @@ -615,7 +603,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap growArray() } else if (numKeys > array.length / 2 * 0.75) { // The fill ratio should be less than 0.75 - sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys") + throw new UnsupportedOperationException( + "Cannot build HashedRelation with more than 1/3 billions unique keys") } } } else { @@ -626,6 +615,25 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } + private def grow(inputRowSize: Int): Unit = { + // There is 8 bytes for the pointer to next value + val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8 + if (neededNumWords > page.length) { + if (neededNumWords > (1 << 30)) { + throw new UnsupportedOperationException( + "Can not build a HashedRelation that is larger than 8G") + } + val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30)) + ensureAcquireMemory(newNumWords * 8L) + val newPage = new Array[Long](newNumWords.toInt) + Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET, + cursor - Platform.LONG_ARRAY_OFFSET) + val used = page.length + page = newPage + freeMemory(used * 8L) + } + } + private def growArray(): Unit = { var old_array = array val n = array.length diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 51f8c3325fdff..037cc2e3ccad7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.CompactBuffer @@ -254,6 +254,30 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { map.free() } + test("SPARK-24257: insert big values into LongToUnsafeRowMap") { + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val unsafeProj = UnsafeProjection.create(Array[DataType](StringType)) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + + val key = 0L + // the page array is initialized with length 1 << 17 (1M bytes), + // so here we need a value larger than 1 << 18 (2M bytes), to trigger the bug + val bigStr = UTF8String.fromString("x" * (1 << 19)) + + map.append(key, unsafeProj(InternalRow(bigStr))) + map.optimize() + + val resultRow = new UnsafeRow(1) + assert(map.getValue(key, resultRow).getUTF8String(0) === bigStr) + map.free() + } + test("Spark-14521") { val ser = new KryoSerializer( (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()