diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 2dd1dc3da96c9..07ee3d008a6ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -533,7 +533,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
   def append(key: Long, row: UnsafeRow): Unit = {
     val sizeInBytes = row.getSizeInBytes
     if (sizeInBytes >= (1 << SIZE_BITS)) {
-      sys.error("Does not support row that is larger than 256M")
+      throw new UnsupportedOperationException("Does not support row that is larger than 256M")
     }
 
     if (key < minKey) {
@@ -543,19 +543,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       maxKey = key
     }
 
-    // There is 8 bytes for the pointer to next value
-    if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) {
-      val used = page.length
-      if (used >= (1 << 30)) {
-        sys.error("Can not build a HashedRelation that is larger than 8G")
-      }
-      ensureAcquireMemory(used * 8L * 2)
-      val newPage = new Array[Long](used * 2)
-      Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
-        cursor - Platform.LONG_ARRAY_OFFSET)
-      page = newPage
-      freeMemory(used * 8L)
-    }
+    grow(row.getSizeInBytes)
 
     // copy the bytes of UnsafeRow
     val offset = cursor
@@ -588,7 +576,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
           growArray()
         } else if (numKeys > array.length / 2 * 0.75) {
           // The fill ratio should be less than 0.75
-          sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys")
+          throw new UnsupportedOperationException(
+            "Cannot build HashedRelation with more than 1/3 billions unique keys")
         }
       }
     } else {
@@ -599,6 +588,25 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     }
   }
 
+  private def grow(inputRowSize: Int): Unit = {
+    // There is 8 bytes for the pointer to next value
+    val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8
+    if (neededNumWords > page.length) {
+      if (neededNumWords > (1 << 30)) {
+        throw new UnsupportedOperationException(
+          "Can not build a HashedRelation that is larger than 8G")
+      }
+      val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30))
+      ensureAcquireMemory(newNumWords * 8L)
+      val newPage = new Array[Long](newNumWords.toInt)
+      Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
+        cursor - Platform.LONG_ARRAY_OFFSET)
+      val used = page.length
+      page = newPage
+      freeMemory(used * 8L)
+    }
+  }
+
   private def growArray(): Unit = {
     var old_array = array
     val n = array.length
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index ede63fea9606f..b575e5570a42c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.collection.CompactBuffer
@@ -253,6 +253,30 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     map.free()
   }
 
+  test("SPARK-24257: insert big values into LongToUnsafeRowMap") {
+    val taskMemoryManager = new TaskMemoryManager(
+      new StaticMemoryManager(
+        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
+        Long.MaxValue,
+        Long.MaxValue,
+        1),
+      0)
+    val unsafeProj = UnsafeProjection.create(Array[DataType](StringType))
+    val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+
+    val key = 0L
+    // the page array is initialized with length 1 << 17 (1M bytes),
+    // so here we need a value larger than 1 << 18 (2M bytes), to trigger the bug
+    val bigStr = UTF8String.fromString("x" * (1 << 19))
+
+    map.append(key, unsafeProj(InternalRow(bigStr)))
+    map.optimize()
+
+    val resultRow = new UnsafeRow(1)
+    assert(map.getValue(key, resultRow).getUTF8String(0) === bigStr)
+    map.free()
+  }
+
   test("Spark-14521") {
     val ser = new KryoSerializer(
       (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()
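
For illustration only, a minimal standalone Scala sketch of the sizing arithmetic that the new grow method relies on: count the 8-byte words needed for the data already in the page plus the incoming row and its 8-byte next-value pointer, then grow to the larger of that and a doubled page, capped at 1 << 30 words (8 GB). The object GrowSizingSketch and the parameters usedBytes, inputRowSize and pageLengthWords are hypothetical stand-ins for cursor - Platform.LONG_ARRAY_OFFSET, row.getSizeInBytes and page.length; this is a sketch of the idea, not Spark code.

object GrowSizingSketch {

  // Returns the new page length in 8-byte words, or None if the current page already fits.
  // usedBytes, inputRowSize and pageLengthWords stand in for cursor - Platform.LONG_ARRAY_OFFSET,
  // row.getSizeInBytes and page.length in the real LongToUnsafeRowMap.
  def newPageWords(usedBytes: Long, inputRowSize: Int, pageLengthWords: Int): Option[Long] = {
    // 8 extra bytes are reserved for the pointer to the next value; "+ 7) / 8" rounds the
    // byte count up to whole 8-byte words.
    val neededNumWords = (usedBytes + 8 + inputRowSize + 7) / 8
    if (neededNumWords <= pageLengthWords) {
      None
    } else if (neededNumWords > (1 << 30)) {
      // 1 << 30 words * 8 bytes per word = 8 GB, the hard cap named in the error message.
      throw new UnsupportedOperationException(
        "Can not build a HashedRelation that is larger than 8G")
    } else {
      // Prefer doubling, but never allocate less than what is needed and never exceed the cap.
      Some(math.max(neededNumWords, math.min(pageLengthWords * 2L, 1L << 30)))
    }
  }

  def main(args: Array[String]): Unit = {
    // With a 1 << 17-word (1 MB) page, a single ~3 MB row must jump straight to the needed
    // size; doubling once (to 2 MB) would still be too small, which is the bug being fixed.
    println(newPageWords(usedBytes = 0L, inputRowSize = 3 << 20, pageLengthWords = 1 << 17))
  }
}

Growing to max(needed, 2 * current length) keeps the amortized doubling behavior for many small appends while also accommodating a single row larger than twice the current page, which the removed code, which always doubled exactly once, could not.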