Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Commit

Permalink
Avoid mod operator when fast alternative available (dotnet/coreclr#27299
Browse files Browse the repository at this point in the history
)

Signed-off-by: dotnet-bot <[email protected]>
  • Loading branch information
benaadams authored and stephentoub committed Oct 26, 2019
1 parent bc78f93 commit 36d9ba1
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 36 deletions.
84 changes: 52 additions & 32 deletions src/Common/src/CoreLib/System/Collections/Generic/Dictionary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ private struct Entry

private int[]? _buckets;
private Entry[]? _entries;
#if BIT64
private ulong _fastModMultiplier;
#endif
private int _count;
private int _freeList;
private int _freeCount;
Expand Down Expand Up @@ -330,16 +333,15 @@ private ref TValue FindValue(TKey key)
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
}

int[]? buckets = _buckets;
ref Entry entry = ref Unsafe.NullRef<Entry>();
if (buckets != null)
if (_buckets != null)
{
Debug.Assert(_entries != null, "expected entries to be != null");
IEqualityComparer<TKey>? comparer = _comparer;
if (comparer == null)
{
uint hashCode = (uint)key.GetHashCode();
int i = buckets[hashCode % (uint)buckets.Length];
int i = GetBucket(hashCode);
Entry[]? entries = _entries;
uint collisionCount = 0;
if (default(TKey)! != null) // TODO-NULLABLE: default(T) == null warning (https://github.com/dotnet/roslyn/issues/34757)
Expand Down Expand Up @@ -407,7 +409,7 @@ private ref TValue FindValue(TKey key)
else
{
uint hashCode = (uint)comparer.GetHashCode(key);
int i = buckets[hashCode % (uint)buckets.Length];
int i = GetBucket(hashCode);
Entry[]? entries = _entries;
uint collisionCount = 0;
// Value in _buckets is 1-based; subtract 1 from i. We do it here so it fuses with the following conditional.
Expand Down Expand Up @@ -453,10 +455,16 @@ private ref TValue FindValue(TKey key)
private int Initialize(int capacity)
{
int size = HashHelpers.GetPrime(capacity);
int[] buckets = new int[size];
Entry[] entries = new Entry[size];

// Assign member variables after both arrays allocated to guard against corruption from OOM if second fails
_freeList = -1;
_buckets = new int[size];
_entries = new Entry[size];
#if BIT64
_fastModMultiplier = HashHelpers.GetFastModMultiplier((uint)size);
#endif
_buckets = buckets;
_entries = entries;

return size;
}
Expand All @@ -481,7 +489,7 @@ private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior)
uint hashCode = (uint)((comparer == null) ? key.GetHashCode() : comparer.GetHashCode(key));

uint collisionCount = 0;
ref int bucket = ref _buckets[hashCode % (uint)_buckets.Length];
ref int bucket = ref GetBucket(hashCode);
// Value in _buckets is 1-based
int i = bucket - 1;

Expand Down Expand Up @@ -625,7 +633,7 @@ private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior)
if (count == entries.Length)
{
Resize();
bucket = ref _buckets[hashCode % (uint)_buckets.Length];
bucket = ref GetBucket(hashCode);
}
index = count;
_count = count + 1;
Expand Down Expand Up @@ -716,7 +724,6 @@ private void Resize(int newSize, bool forceNewHashCodes)
Debug.Assert(_entries != null, "_entries should be non-null");
Debug.Assert(newSize >= _entries.Length);

int[] buckets = new int[newSize];
Entry[] entries = new Entry[newSize];

int count = _count;
Expand All @@ -734,19 +741,23 @@ private void Resize(int newSize, bool forceNewHashCodes)
}
}

// Assign member variables after both arrays allocated to guard against corruption from OOM if second fails
_buckets = new int[newSize];
#if BIT64
_fastModMultiplier = HashHelpers.GetFastModMultiplier((uint)newSize);
#endif
for (int i = 0; i < count; i++)
{
if (entries[i].next >= -1)
{
uint bucket = entries[i].hashCode % (uint)newSize;
ref int bucket = ref GetBucket(entries[i].hashCode);
// Value in _buckets is 1-based
entries[i].next = buckets[bucket] - 1;
entries[i].next = bucket - 1;
// Value in _buckets is 1-based
buckets[bucket] = i + 1;
bucket = i + 1;
}
}

_buckets = buckets;
_entries = entries;
}

Expand All @@ -760,17 +771,16 @@ public bool Remove(TKey key)
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
}

int[]? buckets = _buckets;
Entry[]? entries = _entries;
if (buckets != null)
if (_buckets != null)
{
Debug.Assert(entries != null, "entries should be non-null");
Debug.Assert(_entries != null, "entries should be non-null");
uint collisionCount = 0;
uint hashCode = (uint)(_comparer?.GetHashCode(key) ?? key.GetHashCode());
uint bucket = hashCode % (uint)buckets.Length;
ref int bucket = ref GetBucket(hashCode);
Entry[]? entries = _entries;
int last = -1;
// Value in buckets is 1-based
int i = buckets[bucket] - 1;
int i = bucket - 1;
while (i >= 0)
{
ref Entry entry = ref entries[i];
Expand All @@ -780,7 +790,7 @@ public bool Remove(TKey key)
if (last < 0)
{
// Value in buckets is 1-based
buckets[bucket] = entry.next + 1;
bucket = entry.next + 1;
}
else
{
Expand Down Expand Up @@ -829,17 +839,16 @@ public bool Remove(TKey key, [MaybeNullWhen(false)] out TValue value)
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
}

int[]? buckets = _buckets;
Entry[]? entries = _entries;
if (buckets != null)
if (_buckets != null)
{
Debug.Assert(entries != null, "entries should be non-null");
Debug.Assert(_entries != null, "entries should be non-null");
uint collisionCount = 0;
uint hashCode = (uint)(_comparer?.GetHashCode(key) ?? key.GetHashCode());
uint bucket = hashCode % (uint)buckets.Length;
ref int bucket = ref GetBucket(hashCode);
Entry[]? entries = _entries;
int last = -1;
// Value in buckets is 1-based
int i = buckets[bucket] - 1;
int i = bucket - 1;
while (i >= 0)
{
ref Entry entry = ref entries[i];
Expand All @@ -849,7 +858,7 @@ public bool Remove(TKey key, [MaybeNullWhen(false)] out TValue value)
if (last < 0)
{
// Value in buckets is 1-based
buckets[bucket] = entry.next + 1;
bucket = entry.next + 1;
}
else
{
Expand Down Expand Up @@ -982,6 +991,7 @@ public int EnsureCapacity(int capacity)
_version++;
if (_buckets == null)
return Initialize(capacity);

int newSize = HashHelpers.GetPrime(capacity);
Resize(newSize, forceNewHashCodes: false);
return newSize;
Expand Down Expand Up @@ -1011,8 +1021,8 @@ public void TrimExcess(int capacity)
{
if (capacity < Count)
ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.capacity);
int newSize = HashHelpers.GetPrime(capacity);

int newSize = HashHelpers.GetPrime(capacity);
Entry[]? oldEntries = _entries;
int currentCapacity = oldEntries == null ? 0 : oldEntries.Length;
if (newSize >= currentCapacity)
Expand All @@ -1022,7 +1032,6 @@ public void TrimExcess(int capacity)
_version++;
Initialize(newSize);
Entry[]? entries = _entries;
int[]? buckets = _buckets;
int count = 0;
for (int i = 0; i < oldCount; i++)
{
Expand All @@ -1031,11 +1040,11 @@ public void TrimExcess(int capacity)
{
ref Entry entry = ref entries![count];
entry = oldEntries[i];
uint bucket = hashCode % (uint)newSize;
ref int bucket = ref GetBucket(hashCode);
// Value in _buckets is 1-based
entry.next = buckets![bucket] - 1; // If we get here, we have entries, therefore buckets is not null.
entry.next = bucket - 1;
// Value in _buckets is 1-based
buckets[bucket] = count + 1;
bucket = count + 1;
count++;
}
}
Expand Down Expand Up @@ -1153,6 +1162,17 @@ void IDictionary.Remove(object key)
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private ref int GetBucket(uint hashCode)
{
int[] buckets = _buckets!;
#if BIT64
return ref buckets[HashHelpers.FastMod(hashCode, (uint)buckets.Length, _fastModMultiplier)];
#else
return ref buckets[hashCode % (uint)buckets.Length];
#endif
}

public struct Enumerator : IEnumerator<KeyValuePair<TKey, TValue>>,
IDictionaryEnumerator
{
Expand Down
29 changes: 25 additions & 4 deletions src/Common/src/CoreLib/System/Collections/HashHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System.Diagnostics;
using System.Runtime.CompilerServices;

namespace System.Collections
{
Expand All @@ -28,12 +29,14 @@ internal static partial class HashHelpers
// h1(key) + i*h2(key), 0 <= i < size. h2 and the size must be relatively prime.
// We prefer the low computation costs of higher prime numbers over the increased
// memory allocation of a fixed prime number i.e. when right sizing a HashSet.
public static readonly int[] primes = {
private static readonly int[] s_primes =
{
3, 7, 11, 17, 23, 29, 37, 47, 59, 71, 89, 107, 131, 163, 197, 239, 293, 353, 431, 521, 631, 761, 919,
1103, 1327, 1597, 1931, 2333, 2801, 3371, 4049, 4861, 5839, 7013, 8419, 10103, 12143, 14591,
17519, 21023, 25229, 30293, 36353, 43627, 52361, 62851, 75431, 90523, 108631, 130363, 156437,
187751, 225307, 270371, 324449, 389357, 467237, 560689, 672827, 807403, 968897, 1162687, 1395263,
1674319, 2009191, 2411033, 2893249, 3471899, 4166287, 4999559, 5999471, 7199369 };
1674319, 2009191, 2411033, 2893249, 3471899, 4166287, 4999559, 5999471, 7199369
};

public static bool IsPrime(int candidate)
{
Expand All @@ -55,9 +58,8 @@ public static int GetPrime(int min)
if (min < 0)
throw new ArgumentException(SR.Arg_HTCapacityOverflow);

for (int i = 0; i < primes.Length; i++)
foreach (int prime in s_primes)
{
int prime = primes[i];
if (prime >= min)
return prime;
}
Expand Down Expand Up @@ -86,5 +88,24 @@ public static int ExpandPrime(int oldSize)

return GetPrime(newSize);
}

#if BIT64
public static ulong GetFastModMultiplier(uint divisor)
=> ulong.MaxValue / divisor + 1;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static uint FastMod(uint value, uint divisor, ulong multiplier)
{
// Using fastmod from Daniel Lemire https://lemire.me/blog/2019/02/08/faster-remainders-when-the-divisor-is-a-constant-beating-compilers-and-libdivide/

ulong lowbits = multiplier * value;
// 64bit * 64bit => 128bit isn't currently supported by Math https://github.com/dotnet/corefx/issues/41822
// otherwise we'd want this to be (uint)Math.MultiplyHigh(lowbits, divisor)
uint high = (uint)((((ulong)(uint)lowbits * divisor >> 32) + (lowbits >> 32) * divisor) >> 32);

Debug.Assert(high == value % divisor);
return high;
}
#endif
}
}

0 comments on commit 36d9ba1

Please sign in to comment.