Skip to content

Commit

Permalink
Revert "Avx512 Keccak (#7561)" (#7569)
Browse files Browse the repository at this point in the history
  • Loading branch information
benaadams authored Oct 7, 2024
1 parent dac7890 commit 0944765
Showing 1 changed file with 0 additions and 137 deletions.
137 changes: 0 additions & 137 deletions src/Nethermind/Nethermind.Core/Crypto/KeccakHash.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

using static System.Numerics.BitOperations;

// ReSharper disable InconsistentNaming
Expand Down Expand Up @@ -92,18 +88,6 @@ private KeccakHash(int size)

// Update the state in place by applying the Keccak-f permutation
private static void KeccakF(Span<ulong> st)
{
    // Without AVX-512F support, use the non-AVX-512 implementation.
    if (!Avx512F.IsSupported)
    {
        KeccakF1600(st);
        return;
    }

    // AVX-512F is supported: use the vectorised implementation.
    KeccakF1600Avx512F(st);
}

private static void KeccakF1600(Span<ulong> st)
{
Debug.Assert(st.Length == 25);

Expand Down Expand Up @@ -637,126 +621,5 @@ public static void ReturnState(ref ulong[] state)
state = Array.Empty<ulong>();
}
}

/// <summary>
/// Applies the Keccak-f[1600] permutation to <paramref name="state"/> in place using AVX-512F intrinsics.
/// </summary>
/// <param name="state">
/// The 25 state lanes. Lanes <c>5r .. 5r+4</c> form row <c>r</c>; each row is held in the low five
/// elements of one <see cref="Vector512{T}"/> while the upper three elements are kept at zero.
/// </param>
/// <exception cref="IndexOutOfRangeException">
/// <paramref name="state"/> has fewer than 25 elements.
/// </exception>
/// <remarks>
/// One round is executed per entry of <c>RoundConstants</c>. Callers are expected to check
/// <see cref="Avx512F.IsSupported"/> before calling this method.
/// </remarks>
[SkipLocalsInit]
public static void KeccakF1600Avx512F(Span<ulong> state)
{
    {
        // Redundant read that proves the span holds at least 25 elements,
        // so the indexed accesses below need no further bounds checks.
        _ = state[24];
    }

    ref ulong s0 = ref MemoryMarshal.GetReference(state);

    // Rows 0..3 can be loaded as full 8-lane vectors: the load over-reads into the
    // next row, but never past element 22, so it stays inside the 25-element state.
    Vector512<ulong> mask = Vector512.Create(ulong.MaxValue, ulong.MaxValue, ulong.MaxValue, ulong.MaxValue, ulong.MaxValue, 0UL, 0UL, 0UL);
    Vector512<ulong> c0 = Unsafe.As<ulong, Vector512<ulong>>(ref s0);
    // Clear the over-read lanes (elements 5..7) so only the row's own five lanes remain
    c0 = Vector512.BitwiseAnd(mask, c0);
    Vector512<ulong> c1 = Unsafe.As<ulong, Vector512<ulong>>(ref Unsafe.Add(ref s0, 5));
    c1 = Vector512.BitwiseAnd(mask, c1);
    Vector512<ulong> c2 = Unsafe.As<ulong, Vector512<ulong>>(ref Unsafe.Add(ref s0, 10));
    c2 = Vector512.BitwiseAnd(mask, c2);
    Vector512<ulong> c3 = Unsafe.As<ulong, Vector512<ulong>>(ref Unsafe.Add(ref s0, 15));
    c3 = Vector512.BitwiseAnd(mask, c3);

    // Row 4 cannot be over-read: an 8-lane load from element 20 would run past the
    // end of the state. Read lanes 20..23 as a Vector256 and lane 24 as a scalar,
    // then combine them (lower half = lanes 20..23, upper half = lane 24 + zeros).
    Vector256<ulong> c4a = Unsafe.As<ulong, Vector256<ulong>>(ref Unsafe.Add(ref s0, 20));
    Vector256<ulong> c4b = Vector256.Create(state[24], 0UL, 0UL, 0UL);
    Vector512<ulong> c4 = Vector512.Create(c4a, c4b);

    ulong[] roundConstants = RoundConstants;
    for (int round = 0; round < roundConstants.Length; round++)
    {
        // Theta step: column parities of the five rows
        Vector512<ulong> bVec = Vector512.Xor(Vector512.Xor(Vector512.Xor(c0, c1), Vector512.Xor(c2, c3)), c4);

        // Element x of bVecRot1 is parity[x + 1]; element x of bVecRot4 is parity[x - 1] (mod 5)
        Vector512<ulong> bVecRot1 = Avx512F.PermuteVar8x64(bVec, Vector512.Create(1UL, 2UL, 3UL, 4UL, 0UL, 5UL, 6UL, 7UL));
        Vector512<ulong> bVecRot4 = Avx512F.PermuteVar8x64(bVec, Vector512.Create(4UL, 0UL, 1UL, 2UL, 3UL, 5UL, 6UL, 7UL));

        // Rotate bVecRot1 left by 1 bit
        Vector512<ulong> bVecRot1ShiftedLeft = Avx512F.ShiftLeftLogical(bVecRot1, 1);
        Vector512<ulong> bVecRot1ShiftedRight = Avx512F.ShiftRightLogical(bVecRot1, 63);
        Vector512<ulong> bVecRot1Rotated = Avx512F.Or(bVecRot1ShiftedLeft, bVecRot1ShiftedRight);

        Vector512<ulong> thetaVec = Avx512F.Xor(bVecRot4, bVecRot1Rotated);

        c0 = Avx512F.Xor(c0, thetaVec);
        c1 = Avx512F.Xor(c1, thetaVec);
        c2 = Avx512F.Xor(c2, thetaVec);
        c3 = Avx512F.Xor(c3, thetaVec);
        c4 = Avx512F.Xor(c4, thetaVec);

        // Rho step: per-lane left rotations (padding lanes rotate by 0)
        Vector512<ulong> rhoVec0 = Vector512.Create(0UL, 1UL, 62UL, 28UL, 27UL, 0UL, 0UL, 0UL);
        c0 = Avx512F.RotateLeftVariable(c0, rhoVec0);

        Vector512<ulong> rhoVec1 = Vector512.Create(36UL, 44UL, 6UL, 55UL, 20UL, 0UL, 0UL, 0UL);
        c1 = Avx512F.RotateLeftVariable(c1, rhoVec1);

        Vector512<ulong> rhoVec2 = Vector512.Create(3UL, 10UL, 43UL, 25UL, 39UL, 0UL, 0UL, 0UL);
        c2 = Avx512F.RotateLeftVariable(c2, rhoVec2);

        Vector512<ulong> rhoVec3 = Vector512.Create(41UL, 45UL, 15UL, 21UL, 8UL, 0UL, 0UL, 0UL);
        c3 = Avx512F.RotateLeftVariable(c3, rhoVec3);

        Vector512<ulong> rhoVec4 = Vector512.Create(18UL, 2UL, 61UL, 56UL, 14UL, 0UL, 0UL, 0UL);
        c4 = Avx512F.RotateLeftVariable(c4, rhoVec4);

        // Pi step: each output row gathers one lane from every input row.
        // In the two-source permutes an index of 8 + k selects lane k of the second source.
        Vector512<ulong> c0Pi = Avx512F.PermuteVar8x64x2(c0, Vector512.Create(0UL, 8 + 1, 2, 3, 4, 5, 6, 7), c1);
        c0Pi = Avx512F.PermuteVar8x64x2(c0Pi, Vector512.Create(0UL, 1, 8 + 2, 3, 4, 5, 6, 7), c2);
        c0Pi = Avx512F.PermuteVar8x64x2(c0Pi, Vector512.Create(0UL, 1, 2, 8 + 3, 4, 5, 6, 7), c3);
        c0Pi = Avx512F.PermuteVar8x64x2(c0Pi, Vector512.Create(0UL, 1, 2, 3, 8 + 4, 5, 6, 7), c4);

        Vector512<ulong> c1Pi = Avx512F.PermuteVar8x64x2(c0, Vector512.Create(3UL, 8 + 4, 2, 3, 4, 5, 6, 7), c1);
        c1Pi = Avx512F.PermuteVar8x64x2(c1Pi, Vector512.Create(0UL, 1, 8 + 0, 3, 4, 5, 6, 7), c2);
        c1Pi = Avx512F.PermuteVar8x64x2(c1Pi, Vector512.Create(0UL, 1, 2, 8 + 1, 4, 5, 6, 7), c3);
        c1Pi = Avx512F.PermuteVar8x64x2(c1Pi, Vector512.Create(0UL, 1, 2, 3, 8 + 2, 5, 6, 7), c4);

        Vector512<ulong> c2Pi = Avx512F.PermuteVar8x64x2(c0, Vector512.Create(1UL, 8 + 2, 2, 3, 4, 5, 6, 7), c1);
        c2Pi = Avx512F.PermuteVar8x64x2(c2Pi, Vector512.Create(0UL, 1, 8 + 3, 3, 4, 5, 6, 7), c2);
        c2Pi = Avx512F.PermuteVar8x64x2(c2Pi, Vector512.Create(0UL, 1, 2, 8 + 4, 4, 5, 6, 7), c3);
        c2Pi = Avx512F.PermuteVar8x64x2(c2Pi, Vector512.Create(0UL, 1, 2, 3, 8 + 0, 5, 6, 7), c4);

        Vector512<ulong> c3Pi = Avx512F.PermuteVar8x64x2(c0, Vector512.Create(4UL, 8 + 0, 2, 3, 4, 5, 6, 7), c1);
        c3Pi = Avx512F.PermuteVar8x64x2(c3Pi, Vector512.Create(0UL, 1, 8 + 1, 3, 4, 5, 6, 7), c2);
        c3Pi = Avx512F.PermuteVar8x64x2(c3Pi, Vector512.Create(0UL, 1, 2, 8 + 2, 4, 5, 6, 7), c3);
        c3Pi = Avx512F.PermuteVar8x64x2(c3Pi, Vector512.Create(0UL, 1, 2, 3, 8 + 3, 5, 6, 7), c4);

        Vector512<ulong> c4Pi = Avx512F.PermuteVar8x64x2(c0, Vector512.Create(2UL, 8 + 3, 2, 3, 4, 5, 6, 7), c1);
        c4Pi = Avx512F.PermuteVar8x64x2(c4Pi, Vector512.Create(0UL, 1, 8 + 4, 3, 4, 5, 6, 7), c2);
        c4Pi = Avx512F.PermuteVar8x64x2(c4Pi, Vector512.Create(0UL, 1, 2, 8 + 0, 4, 5, 6, 7), c3);
        c4Pi = Avx512F.PermuteVar8x64x2(c4Pi, Vector512.Create(0UL, 1, 2, 3, 8 + 1, 5, 6, 7), c4);

        c0 = c0Pi;
        c1 = c1Pi;
        c2 = c2Pi;
        c3 = c3Pi;
        c4 = c4Pi;

        // Chi step: a ^ (~b & c), where b and c are the row rotated by one and two lanes.
        // 0xD2 is the truth table of that expression for TernaryLogic(a, b, c).
        Vector512<ulong> permute1 = Vector512.Create(1UL, 2UL, 3UL, 4UL, 0UL, 5UL, 6UL, 7UL);
        Vector512<ulong> permute2 = Vector512.Create(2UL, 3UL, 4UL, 0UL, 1UL, 5UL, 6UL, 7UL);

        c0 = Avx512F.TernaryLogic(c0, Avx512F.PermuteVar8x64(c0, permute1), Avx512F.PermuteVar8x64(c0, permute2), 0xD2);
        c1 = Avx512F.TernaryLogic(c1, Avx512F.PermuteVar8x64(c1, permute1), Avx512F.PermuteVar8x64(c1, permute2), 0xD2);
        c2 = Avx512F.TernaryLogic(c2, Avx512F.PermuteVar8x64(c2, permute1), Avx512F.PermuteVar8x64(c2, permute2), 0xD2);
        c3 = Avx512F.TernaryLogic(c3, Avx512F.PermuteVar8x64(c3, permute1), Avx512F.PermuteVar8x64(c3, permute2), 0xD2);
        c4 = Avx512F.TernaryLogic(c4, Avx512F.PermuteVar8x64(c4, permute1), Avx512F.PermuteVar8x64(c4, permute2), 0xD2);

        // Iota step: the round constant only affects lane 0 of row 0
        c0 = Vector512.Xor(c0, Vector512.Create(roundConstants[round], 0UL, 0UL, 0UL, 0UL, 0UL, 0UL, 0UL));
    }

    // Rows 0..3 can be stored as full vectors: the three padding lanes spill into the
    // next row's slots, which the following store then overwrites with real data.
    Unsafe.As<ulong, Vector512<ulong>>(ref s0) = c0;
    Unsafe.As<ulong, Vector512<ulong>>(ref Unsafe.Add(ref s0, 5)) = c1;
    Unsafe.As<ulong, Vector512<ulong>>(ref Unsafe.Add(ref s0, 10)) = c2;
    Unsafe.As<ulong, Vector512<ulong>>(ref Unsafe.Add(ref s0, 15)) = c3;
    // Row 4 cannot over-write past the end of the state, so store its four low lanes
    // (lanes 20..23 live in the LOWER half of c4) and then lane 24 as a scalar.
    // This also replaces the padding that the c3 store spilled into lanes 20..22.
    Unsafe.As<ulong, Vector256<ulong>>(ref Unsafe.Add(ref s0, 20)) = c4.GetLower();
    state[24] = c4.GetElement(4);
}
}
}

0 comments on commit 0944765

Please sign in to comment.