Skip to content

Commit

Permalink
Inline more Avx2 helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
sbomer committed Apr 24, 2023
1 parent 0342acc commit ef9e53b
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 137 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,12 @@ internal static int IndexOfAnyVectorized<TNegator, TOptimizations>(ref short sea
Vector256<byte> result = IndexOfAnyLookup<TNegator, TOptimizations>(source0, source1, bitmap256);
if (result != Vector256<byte>.Zero)
{
return ComputeFirstIndex<short, TNegator>(ref searchSpace, ref currentSearchSpace, result);
result = Avx2.Permute4x64(result.AsInt64(), 0b_11_01_10_00).AsByte();

uint mask = TNegator.ExtractMask(result);

int offsetInVector = BitOperations.TrailingZeroCount(mask);
return offsetInVector + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref currentSearchSpace) / (nuint)sizeof(short));
}

currentSearchSpace = ref Unsafe.Add(ref currentSearchSpace, 2 * Vector256<short>.Count);
Expand All @@ -219,7 +224,18 @@ internal static int IndexOfAnyVectorized<TNegator, TOptimizations>(ref short sea
Vector256<byte> result = IndexOfAnyLookup<TNegator, TOptimizations>(source0, source1, bitmap256);
if (result != Vector256<byte>.Zero)
{
return ComputeFirstIndexOverlapped<short, TNegator>(ref searchSpace, ref firstVector, ref oneVectorAwayFromEnd, result);
result = Avx2.Permute4x64(result.AsInt64(), 0b_11_01_10_00).AsByte();

uint mask = TNegator.ExtractMask(result);

int offsetInVector = BitOperations.TrailingZeroCount(mask);
if (offsetInVector >= Vector256<short>.Count)
{
// We matched within the second vector
firstVector = ref oneVectorAwayFromEnd;
offsetInVector -= Vector256<short>.Count;
}
return offsetInVector + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref firstVector) / (nuint)sizeof(short));
}
}

Expand Down Expand Up @@ -307,7 +323,12 @@ internal static int LastIndexOfAnyVectorized<TNegator, TOptimizations>(ref short
Vector256<byte> result = IndexOfAnyLookup<TNegator, TOptimizations>(source0, source1, bitmap256);
if (result != Vector256<byte>.Zero)
{
return ComputeLastIndex<short, TNegator>(ref searchSpace, ref currentSearchSpace, result);
result = Avx2.Permute4x64(result.AsInt64(), 0b_11_01_10_00).AsByte();

uint mask = TNegator.ExtractMask(result);

int offsetInVector = 31 - BitOperations.LeadingZeroCount(mask);
return offsetInVector + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref currentSearchSpace) / (nuint)sizeof(short));
}
}
while (Unsafe.IsAddressGreaterThan(ref currentSearchSpace, ref twoVectorsAfterStart));
Expand All @@ -329,7 +350,18 @@ internal static int LastIndexOfAnyVectorized<TNegator, TOptimizations>(ref short
Vector256<byte> result = IndexOfAnyLookup<TNegator, TOptimizations>(source0, source1, bitmap256);
if (result != Vector256<byte>.Zero)
{
return ComputeLastIndexOverlapped<short, TNegator>(ref searchSpace, ref secondVector, result);
result = Avx2.Permute4x64(result.AsInt64(), 0b_11_01_10_00).AsByte();

uint mask = TNegator.ExtractMask(result);

int offsetInVector = 31 - BitOperations.LeadingZeroCount(mask);
if (offsetInVector < Vector256<short>.Count)
{
return offsetInVector;
}

// We matched within the second vector
return offsetInVector - Vector256<short>.Count + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref secondVector) / (nuint)sizeof(short));
}
}

Expand Down Expand Up @@ -411,7 +443,10 @@ internal static int IndexOfAnyVectorized<TNegator>(ref byte searchSpace, int sea
Vector256<byte> result = TNegator.NegateIfNeeded(IndexOfAnyLookupCore(source, bitmap256));
if (result != Vector256<byte>.Zero)
{
return ComputeFirstIndex<byte, TNegator>(ref searchSpace, ref currentSearchSpace, result);
uint mask = TNegator.ExtractMask(result);

int offsetInVector = BitOperations.TrailingZeroCount(mask);
return offsetInVector + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref currentSearchSpace) / (nuint)sizeof(byte));
}

currentSearchSpace = ref Unsafe.Add(ref currentSearchSpace, Vector256<byte>.Count);
Expand All @@ -436,7 +471,16 @@ internal static int IndexOfAnyVectorized<TNegator>(ref byte searchSpace, int sea
Vector256<byte> result = TNegator.NegateIfNeeded(IndexOfAnyLookupCore(source, bitmap256));
if (result != Vector256<byte>.Zero)
{
return ComputeFirstIndexOverlapped<byte, TNegator>(ref searchSpace, ref firstVector, ref halfVectorAwayFromEnd, result);
uint mask = TNegator.ExtractMask(result);

int offsetInVector = BitOperations.TrailingZeroCount(mask);
if (offsetInVector >= Vector256<short>.Count)
{
// We matched within the second vector
firstVector = ref halfVectorAwayFromEnd;
offsetInVector -= Vector256<short>.Count;
}
return offsetInVector + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref firstVector) / (nuint)sizeof(byte));
}
}

Expand Down Expand Up @@ -518,7 +562,10 @@ internal static int LastIndexOfAnyVectorized<TNegator>(ref byte searchSpace, int
Vector256<byte> result = TNegator.NegateIfNeeded(IndexOfAnyLookupCore(source, bitmap256));
if (result != Vector256<byte>.Zero)
{
return ComputeLastIndex<byte, TNegator>(ref searchSpace, ref currentSearchSpace, result);
uint mask = TNegator.ExtractMask(result);

int offsetInVector = 31 - BitOperations.LeadingZeroCount(mask);
return offsetInVector + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref currentSearchSpace) / (nuint)sizeof(byte));
}
}
while (Unsafe.IsAddressGreaterThan(ref currentSearchSpace, ref vectorAfterStart));
Expand All @@ -541,7 +588,16 @@ internal static int LastIndexOfAnyVectorized<TNegator>(ref byte searchSpace, int
Vector256<byte> result = TNegator.NegateIfNeeded(IndexOfAnyLookupCore(source, bitmap256));
if (result != Vector256<byte>.Zero)
{
return ComputeLastIndexOverlapped<byte, TNegator>(ref searchSpace, ref secondVector, result);
uint mask = TNegator.ExtractMask(result);

int offsetInVector = 31 - BitOperations.LeadingZeroCount(mask);
if (offsetInVector < Vector256<short>.Count)
{
return offsetInVector;
}

// We matched within the second vector
return offsetInVector - Vector256<short>.Count + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref secondVector) / (nuint)sizeof(byte));
}
}

Expand Down Expand Up @@ -622,7 +678,10 @@ internal static int IndexOfAnyVectorized<TNegator>(ref byte searchSpace, int sea
Vector256<byte> result = IndexOfAnyLookup<TNegator>(source, bitmap256_0, bitmap256_1);
if (result != Vector256<byte>.Zero)
{
return ComputeFirstIndex<byte, TNegator>(ref searchSpace, ref currentSearchSpace, result);
uint mask = TNegator.ExtractMask(result);

int offsetInVector = BitOperations.TrailingZeroCount(mask);
return offsetInVector + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref currentSearchSpace) / (nuint)sizeof(byte));
}

currentSearchSpace = ref Unsafe.Add(ref currentSearchSpace, Vector256<byte>.Count);
Expand All @@ -647,7 +706,16 @@ internal static int IndexOfAnyVectorized<TNegator>(ref byte searchSpace, int sea
Vector256<byte> result = IndexOfAnyLookup<TNegator>(source, bitmap256_0, bitmap256_1);
if (result != Vector256<byte>.Zero)
{
return ComputeFirstIndexOverlapped<byte, TNegator>(ref searchSpace, ref firstVector, ref halfVectorAwayFromEnd, result);
uint mask = TNegator.ExtractMask(result);

int offsetInVector = BitOperations.TrailingZeroCount(mask);
if (offsetInVector >= Vector256<short>.Count)
{
// We matched within the second vector
firstVector = ref halfVectorAwayFromEnd;
offsetInVector -= Vector256<short>.Count;
}
return offsetInVector + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref firstVector) / (nuint)sizeof(byte));
}
}

Expand Down Expand Up @@ -730,7 +798,10 @@ internal static int LastIndexOfAnyVectorized<TNegator>(ref byte searchSpace, int
Vector256<byte> result = IndexOfAnyLookup<TNegator>(source, bitmap256_0, bitmap256_1);
if (result != Vector256<byte>.Zero)
{
return ComputeLastIndex<byte, TNegator>(ref searchSpace, ref currentSearchSpace, result);
uint mask = TNegator.ExtractMask(result);

int offsetInVector = 31 - BitOperations.LeadingZeroCount(mask);
return offsetInVector + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref currentSearchSpace) / (nuint)sizeof(byte));
}
}
while (Unsafe.IsAddressGreaterThan(ref currentSearchSpace, ref vectorAfterStart));
Expand All @@ -753,7 +824,16 @@ internal static int LastIndexOfAnyVectorized<TNegator>(ref byte searchSpace, int
Vector256<byte> result = IndexOfAnyLookup<TNegator>(source, bitmap256_0, bitmap256_1);
if (result != Vector256<byte>.Zero)
{
return ComputeLastIndexOverlapped<byte, TNegator>(ref searchSpace, ref secondVector, result);
uint mask = TNegator.ExtractMask(result);

int offsetInVector = 31 - BitOperations.LeadingZeroCount(mask);
if (offsetInVector < Vector256<short>.Count)
{
return offsetInVector;
}

// We matched within the second vector
return offsetInVector - Vector256<short>.Count + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref secondVector) / (nuint)sizeof(byte));
}
}

Expand Down Expand Up @@ -992,89 +1072,6 @@ private static unsafe int ComputeLastIndexOverlapped<T, TNegator>(ref T searchSp
return offsetInVector - Vector128<short>.Count + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref secondVector) / (nuint)sizeof(T));
}

// Translates a non-zero Vector256 lookup result into the index (counted in T elements from
// searchSpace) of the lowest set bit in the mask produced by TNegator.ExtractMask.
//   searchSpace - reference to the start of the searched region.
//   current     - reference to the position the result vector was produced from.
//   result      - lookup result; for T == short it is first passed through
//                 FixUpPackedVector256Result before the mask is extracted.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe int ComputeFirstIndex<T, TNegator>(ref T searchSpace, ref T current, Vector256<byte> result)
    where TNegator : struct, INegator
{
    // Only the short path needs the fix-up step; byte results are used as-is.
    Vector256<byte> ordered = typeof(T) == typeof(short)
        ? FixUpPackedVector256Result(result)
        : result;

    uint matchBits = TNegator.ExtractMask(ordered);

    // Distance from the start of the search space to 'current', converted from bytes to elements.
    nuint elementsBeforeCurrent = (nuint)Unsafe.ByteOffset(ref searchSpace, ref current) / (nuint)sizeof(T);

    // The lowest set bit of the mask gives the position inside the vector.
    return BitOperations.TrailingZeroCount(matchBits) + (int)elementsBeforeCurrent;
}

// Variant of ComputeFirstIndex for a result vector built from two separately positioned halves.
// Mask positions 0..15 are resolved relative to current0; positions 16 and above are
// resolved relative to current1 after subtracting Vector256<short>.Count (16).
//   searchSpace - reference to the start of the searched region.
//   current0    - reference that the lower 16 mask positions are measured from.
//   current1    - reference that the upper 16 mask positions are measured from.
//   result      - lookup result; for T == short it is first passed through
//                 FixUpPackedVector256Result before the mask is extracted.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe int ComputeFirstIndexOverlapped<T, TNegator>(ref T searchSpace, ref T current0, ref T current1, Vector256<byte> result)
    where TNegator : struct, INegator
{
    Vector256<byte> ordered = typeof(T) == typeof(short)
        ? FixUpPackedVector256Result(result)
        : result;

    uint matchBits = TNegator.ExtractMask(ordered);
    int lane = BitOperations.TrailingZeroCount(matchBits);

    if (lane < Vector256<short>.Count)
    {
        // We matched within the first vector
        return lane + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref current0) / (nuint)sizeof(T));
    }

    // We matched within the second vector
    int laneInSecond = lane - Vector256<short>.Count;
    return laneInSecond + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref current1) / (nuint)sizeof(T));
}

// Translates a non-zero Vector256 lookup result into the index (counted in T elements from
// searchSpace) of the highest set bit in the mask produced by TNegator.ExtractMask.
//   searchSpace - reference to the start of the searched region.
//   current     - reference to the position the result vector was produced from.
//   result      - lookup result; for T == short it is first passed through
//                 FixUpPackedVector256Result before the mask is extracted.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe int ComputeLastIndex<T, TNegator>(ref T searchSpace, ref T current, Vector256<byte> result)
    where TNegator : struct, INegator
{
    // Only the short path needs the fix-up step; byte results are used as-is.
    Vector256<byte> ordered = typeof(T) == typeof(short)
        ? FixUpPackedVector256Result(result)
        : result;

    uint matchBits = TNegator.ExtractMask(ordered);

    // Position of the highest set bit in the 32-bit mask.
    int lane = 31 - BitOperations.LeadingZeroCount(matchBits);

    // Distance from the start of the search space to 'current', converted from bytes to elements.
    nuint elementsBeforeCurrent = (nuint)Unsafe.ByteOffset(ref searchSpace, ref current) / (nuint)sizeof(T);

    return lane + (int)elementsBeforeCurrent;
}

// Variant of ComputeLastIndex for a result vector built from two separately positioned halves.
// A highest set bit in positions 0..15 is returned unchanged; a bit in positions 16 and above
// is rebased onto secondVector after subtracting Vector256<short>.Count (16).
// NOTE(review): returning the low positions without any offset presumably relies on the first
// half having been loaded from the very start of searchSpace - confirm against the callers.
//   searchSpace  - reference to the start of the searched region.
//   secondVector - reference that the upper 16 mask positions are measured from.
//   result       - lookup result; for T == short it is first passed through
//                  FixUpPackedVector256Result before the mask is extracted.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe int ComputeLastIndexOverlapped<T, TNegator>(ref T searchSpace, ref T secondVector, Vector256<byte> result)
    where TNegator : struct, INegator
{
    Vector256<byte> ordered = typeof(T) == typeof(short)
        ? FixUpPackedVector256Result(result)
        : result;

    uint matchBits = TNegator.ExtractMask(ordered);

    // Position of the highest set bit in the 32-bit mask.
    int lane = 31 - BitOperations.LeadingZeroCount(matchBits);

    if (lane >= Vector256<short>.Count)
    {
        // We matched within the second vector
        return lane - Vector256<short>.Count + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref secondVector) / (nuint)sizeof(T));
    }

    return lane;
}

// Reorders the four 64-bit chunks of a packed Vector256 result so that the two middle chunks
// trade places (chunk order 0, 2, 1, 3). Requires AVX2.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector256<byte> FixUpPackedVector256Result(Vector256<byte> result)
{
    Debug.Assert(Avx2.IsSupported);

    // Avx2.PackUnsignedSaturate(Vector256.Create((short)1), Vector256.Create((short)2)) yields
    //   1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2
    // i.e. the two inputs are interleaved in 8-byte groups. Swapping the X and Y groups below
    //   1, 1, 1, 1, 1, 1, 1, 1, X, X, X, X, X, X, X, X, Y, Y, Y, Y, Y, Y, Y, Y, 2, 2, 2, 2, 2, 2, 2, 2
    // places all bytes from the first input ahead of all bytes from the second.
    const byte SwapMiddleQuadwords = 0b_11_01_10_00;

    Vector256<long> quadwords = result.AsInt64();
    Vector256<long> reordered = Avx2.Permute4x64(quadwords, SwapMiddleQuadwords);
    return reordered.AsByte();
}

internal interface INegator
{
static abstract bool NegateIfNeeded(bool result);
Expand Down
Loading

0 comments on commit ef9e53b

Please sign in to comment.