diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 1501111a06655..60deb91aa8fc6 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -43,7 +43,8 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( loadFactor: Double) extends Serializable { - require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") + require(initialCapacity <= OpenHashSet.MAX_CAPACITY, + s"Can't make capacity bigger than ${OpenHashSet.MAX_CAPACITY} elements") require(initialCapacity >= 1, "Invalid initial capacity") require(loadFactor < 1.0, "Load factor must be less than 1.0") require(loadFactor > 0.0, "Load factor must be greater than 0.0") @@ -213,6 +214,8 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( */ private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { val newCapacity = _capacity * 2 + require(newCapacity > 0 && newCapacity <= OpenHashSet.MAX_CAPACITY, + s"Can't contain more than ${(loadFactor * OpenHashSet.MAX_CAPACITY).toInt} elements") allocateFunc(newCapacity) val newBitset = new BitSet(newCapacity) val newData = new Array[T](newCapacity) @@ -266,9 +269,10 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( private[spark] object OpenHashSet { + val MAX_CAPACITY = 1 << 30 val INVALID_POS = -1 - val NONEXISTENCE_MASK = 0x80000000 - val POSITION_MASK = 0xEFFFFFF + val NONEXISTENCE_MASK = 1 << 31 + val POSITION_MASK = (1 << 31) - 1 /** * A set of specialized hash function implementation to avoid boxing hash code computation diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index 94e011799921b..3066e9996abda 100644 --- 
a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -44,7 +44,7 @@ class OpenHashMapSuite extends SparkFunSuite with Matchers { val goodMap3 = new OpenHashMap[String, String](256) assert(goodMap3.size === 0) intercept[IllegalArgumentException] { - new OpenHashMap[String, Int](1 << 30) // Invalid map size: bigger than 2^29 + new OpenHashMap[String, Int]((1 << 30) + 1) // Invalid map size: bigger than 2^30 } intercept[IllegalArgumentException] { new OpenHashMap[String, Int](-1) @@ -186,4 +186,14 @@ class OpenHashMapSuite extends SparkFunSuite with Matchers { map(null) = 0 assert(map.contains(null)) } + + test("support for more than 12M items") { + val cnt = 12000000 // 12M + val map = new OpenHashMap[Int, Int](cnt) + for (i <- 0 until cnt) { + map(i) = 1 + } + val numInvalidValues = map.iterator.count(_._2 == 0) + assertResult(0)(numInvalidValues) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala index 462bc2f29f9f8..508e737b725bc 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -44,7 +44,7 @@ class PrimitiveKeyOpenHashMapSuite extends SparkFunSuite with Matchers { val goodMap3 = new PrimitiveKeyOpenHashMap[Int, Int](256) assert(goodMap3.size === 0) intercept[IllegalArgumentException] { - new PrimitiveKeyOpenHashMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29 + new PrimitiveKeyOpenHashMap[Int, Int]((1 << 30) + 1) // Invalid map size: bigger than 2^30 } intercept[IllegalArgumentException] { new PrimitiveKeyOpenHashMap[Int, Int](-1)