Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various Hashing Improvements #660

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,20 @@ public class ByteArrayOrdinalMap {

private long[] pointersByOrdinal;

private static int nextLargestPrime(int num) {
int[] precomputedPrimes = new int[]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we pick prime numbers close to these numbers that are also (2^n-1), then we can replace the modulo with simpler bitwise operations

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! Note the sizable gaps between sequential Mersenne primes (like 127 and 521), which might impact resizing memory efficiency. Considering we're not using modulo during LP, I think the current approach is acceptable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Restricting to mersenne primes will result in significant memory overhead -- each mersenne number is ~double the previous, and many of these are not prime, so an increase will often result in a significant step larger than doubling the space.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah they seem too few and far between 👍

{257, 521, 1049, 2099, 4139, 8287, 16553, 33107, 66221, 132421, 264839, 529673, 1059343, 2118661, 4237319, 8474633,
16949269, 33898507, 67796999, 135593987, 271187993, 542375947, 1084751881, Integer.MAX_VALUE};

for(int prime : precomputedPrimes) {
if(prime >= num) {
return prime;
}
}

return Integer.MAX_VALUE;
}

/**
* Creates a byte array ordinal map with a an initial capacity of 256 elements,
* and a load factor of 70%.
Expand All @@ -72,7 +86,7 @@ public ByteArrayOrdinalMap() {
* rounded up to the nearest power of two, and a load factor of 70%.
*/
public ByteArrayOrdinalMap(int size) {
size = bucketSize(size);
size = nextLargestPrime(size);

this.freeOrdinalTracker = new FreeOrdinalTracker();
this.byteData = new ByteDataArray(WastefulRecycler.DEFAULT_INSTANCE);
Expand Down Expand Up @@ -132,16 +146,18 @@ private synchronized int assignOrdinal(ByteDataArray serializedRepresentation, i
/// Note that this also requires pointersAndOrdinals be volatile so resizes are also visible
AtomicLongArray pao = pointersAndOrdinals;

int modBitmask = pao.length() - 1;
int bucket = hash & modBitmask;
int bucket = indexFromHash(hash, pao.length());
long key = pao.get(bucket);

while (key != EMPTY_BUCKET_VALUE) {
if (compare(serializedRepresentation, key)) {
return (int) (key >>> BITS_PER_POINTER);
}

bucket = (bucket + 1) & modBitmask;
bucket = (bucket + 1);
if(bucket == pao.length()) {
bucket = 0;
}
key = pao.get(bucket);
}

Expand Down Expand Up @@ -192,6 +208,11 @@ private int findFreeOrdinal(int preferredOrdinal) {
return freeOrdinalTracker.getFreeOrdinal();
}

private static int indexFromHash(int hashedValue, int length) {
int modulus = hashedValue % length;
return modulus < 0 ? modulus + length : modulus;
}

/**
* Assign a predefined ordinal to a serialized representation.<p>
* <p>
Expand All @@ -215,12 +236,14 @@ public void put(ByteDataArray serializedRepresentation, int ordinal) {

AtomicLongArray pao = pointersAndOrdinals;

int modBitmask = pao.length() - 1;
int bucket = hash & modBitmask;
int bucket = indexFromHash(hash, pao.length());
long key = pao.get(bucket);

while (key != EMPTY_BUCKET_VALUE) {
bucket = (bucket + 1) & modBitmask;
bucket = (bucket + 1);
if(bucket==pao.length()) {
bucket = 0;
}
key = pao.get(bucket);
}

Expand Down Expand Up @@ -295,8 +318,7 @@ public int get(ByteDataArray serializedRepresentation) {
private int get(ByteDataArray serializedRepresentation, int hash) {
AtomicLongArray pao = pointersAndOrdinals;

int modBitmask = pao.length() - 1;
int bucket = hash & modBitmask;
int bucket = indexFromHash(hash, pao.length());
long key = pao.get(bucket);

// Linear probing to resolve collisions
Expand All @@ -310,7 +332,10 @@ private int get(ByteDataArray serializedRepresentation, int hash) {
return (int) (key >>> BITS_PER_POINTER);
}

bucket = (bucket + 1) & modBitmask;
bucket = (bucket + 1);
if(bucket == pao.length()) {
bucket = 0;
}
key = pao.get(bucket);
}

Expand Down Expand Up @@ -484,7 +509,7 @@ private boolean compare(ByteDataArray serializedRepresentation, long key) {
* @param size the size to increase to, rounded up to the nearest power of two.
*/
public void resize(int size) {
size = bucketSize(size);
size = nextLargestPrime(size);

if (pointersAndOrdinals.length() < size) {
growKeyArray(size);
Expand All @@ -507,7 +532,7 @@ private void growKeyArray() {

private void growKeyArray(int newSize) {
AtomicLongArray pao = pointersAndOrdinals;
assert (newSize & (newSize - 1)) == 0; // power of 2
//assert (newSize & (newSize - 1)) == 0; // power of 2
assert pao.length() < newSize;

AtomicLongArray newKeys = emptyKeyArray(newSize);
Expand Down Expand Up @@ -545,15 +570,15 @@ private void populateNewHashArray(AtomicLongArray newKeys, long[] valuesToAdd) {
private void populateNewHashArray(AtomicLongArray newKeys, long[] valuesToAdd, int length) {
assert length <= valuesToAdd.length;

int modBitmask = newKeys.length() - 1;

for (int i = 0; i < length; i++) {
long value = valuesToAdd[i];
if (value != EMPTY_BUCKET_VALUE) {
int hash = rehashPreviouslyAddedData(value);
int bucket = hash & modBitmask;
int bucket = indexFromHash(hash, newKeys.length());
while (newKeys.get(bucket) != EMPTY_BUCKET_VALUE) {
bucket = (bucket + 1) & modBitmask;
bucket = (bucket + 1);
if(bucket == newKeys.length())
bucket = 0;
}
// Volatile store not required, could use plain store
// See VarHandles for JDK >= 9
Expand All @@ -578,13 +603,9 @@ private int rehashPreviouslyAddedData(long key) {
* Create an AtomicLongArray of the specified size, each value in the array will be EMPTY_BUCKET_VALUE
*/
private AtomicLongArray emptyKeyArray(int size) {
AtomicLongArray arr = new AtomicLongArray(size);
// Volatile store not required, could use plain store
// See VarHandles for JDK >= 9
for (int i = 0; i < arr.length(); i++) {
arr.lazySet(i, EMPTY_BUCKET_VALUE);
}
return arr;
long[] mtArray = new long[size];
Arrays.fill(mtArray, EMPTY_BUCKET_VALUE);
return new AtomicLongArray(mtArray);
}

public ByteDataArray getByteData() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,15 +201,33 @@ private int hashFromIndex(int index) {
return hashKeyRecord(fieldObjects);
}

// We want to keep hash table sizes prime numbers, but we don't want to have to compute very large primes during
// runtime. These are precomputed primes, each ~2x the size of the previous
private static int nextLargestPrime(int num) {
int[] precomputedPrimes = new int[]
{4139, 8287, 16553, 33107, 66221, 132421, 264839, 529673, 1059343, 2118661, 4237319, 8474633,
16949269, 33898507, 67796999, 135593987, 271187993, 542375947, 1084751881, Integer.MAX_VALUE};

for(int prime : precomputedPrimes) {
if(prime >= num) {
return prime;
}
}

return Integer.MAX_VALUE;
}

private void expandAndRehashTable() {
prepareForRead();

int[] newTable = new int[hashToAssignedOrdinal.length*2];
int newTableSize = nextLargestPrime(hashToAssignedOrdinal.length*2);

int[] newTable = new int[newTableSize];
Arrays.fill(newTable, ORDINAL_NONE);

int[][] newFieldMappings = new int[primaryKey.numFields()][hashToAssignedOrdinal.length*2];
IntList[][] newFieldHashToOrdinal = new IntList[primaryKey.numFields()][hashToAssignedOrdinal.length*2];
assignedOrdinalToIndex = Arrays.copyOf(assignedOrdinalToIndex, hashToAssignedOrdinal.length*2);
int[][] newFieldMappings = new int[primaryKey.numFields()][newTableSize];
IntList[][] newFieldHashToOrdinal = new IntList[primaryKey.numFields()][newTableSize];
assignedOrdinalToIndex = Arrays.copyOf(assignedOrdinalToIndex, newTableSize);

for(int fieldIdx=0;fieldIdx<primaryKey.numFields();fieldIdx++) {
IntList[] hashToOrdinal = fieldHashToAssignedOrdinal[fieldIdx];
Expand Down
Loading