Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize Keccak state copy, remove temp array use #7663

Merged
merged 5 commits into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Nethermind/Nethermind.Core/Crypto/Keccak.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public static Hash256 Compute(byte[]? input)
return OfAnEmptyString;
}

return new Hash256(KeccakHash.ComputeHashBytes(input));
return new Hash256(ValueKeccak.Compute(input));
}

[DebuggerStepThrough]
Expand Down
252 changes: 152 additions & 100 deletions src/Nethermind/Nethermind.Core/Crypto/KeccakHash.cs
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,6 @@ private static void KeccakF1600(Span<ulong> st)
st[0] = aba;
}

/// <summary>
/// Convenience overload: allocates a fresh buffer of <paramref name="size"/> bytes,
/// hashes <paramref name="input"/> into it and hands the buffer back as a span.
/// </summary>
/// <param name="input">The bytes to hash.</param>
/// <param name="size">Length of the digest buffer to allocate, in bytes.</param>
/// <returns>A span over the newly allocated buffer that received the digest.</returns>
public static Span<byte> ComputeHash(ReadOnlySpan<byte> input, int size = HASH_SIZE)
{
    byte[] digest = new byte[size];

    // Delegate to the span-based overload, which sizes the hash from the destination length.
    ComputeHash(input, digest.AsSpan());

    return digest;
}

public static byte[] ComputeHashBytes(ReadOnlySpan<byte> input, int size = HASH_SIZE)
{
byte[] output = new byte[size];
Expand Down Expand Up @@ -356,47 +349,133 @@ public static uint[] ComputeBytesToUint(ReadOnlySpan<byte> input, int size)
/// <summary>
/// Computes a Keccak hash of <paramref name="input"/> and writes it to <paramref name="output"/>.
/// The digest length is <c>output.Length</c>; the block ("round") size is obtained from
/// <c>GetRoundSize(output.Length)</c>.
/// </summary>
/// <param name="input">The bytes to hash; may be empty.</param>
/// <param name="output">
/// Destination for the digest. Its length must be in the range 1..STATE_SIZE and must map to a
/// round size in the range 1..STATE_SIZE; otherwise <c>ThrowBadKeccak</c> is invoked.
/// </param>
public static void ComputeHash(ReadOnlySpan<byte> input, Span<byte> output)
{
    int roundSize = GetRoundSize(output.Length);

    // The round size is validated together with the output length: a round size of zero would make
    // the absorb loop below spin forever, and anything outside 1..STATE_SIZE would index outside the
    // state when the final padding bit is written.
    if (output.Length <= 0 || output.Length > STATE_SIZE || roundSize <= 0 || roundSize > STATE_SIZE)
    {
        ThrowBadKeccak();
    }

    Span<ulong> state = stackalloc ulong[STATE_SIZE / sizeof(ulong)];
    Span<byte> stateBytes = MemoryMarshal.AsBytes(state);

    if (input.Length >= roundSize)
    {
        // Multi-block input: absorb every full block, permuting after each one.
        // This is checked first so the fixed-size shortcuts below are only used when the
        // whole input fits inside a single block.
        do
        {
            XorVectors(stateBytes, input.Slice(0, roundSize));
            KeccakF(state);
            input = input.Slice(roundSize);
        } while (input.Length >= roundSize);

        if (input.Length > 0)
        {
            // XOR the trailing partial block into the state; it is permuted with the padding below.
            XorVectors(stateBytes, input);
        }
    }
    else if (input.Length == Address.Size)
    {
        // Hashing an Address: 20 bytes, moved as one uint followed by one Vector128.
        // The shortcuts copy rather than XOR, so they rely on the freshly allocated state being
        // all zero. NOTE(review): confirm locals-init is not skipped for this method.
        ref byte stateRef = ref MemoryMarshal.GetReference(stateBytes);
        ref byte inputRef = ref MemoryMarshal.GetReference(input);

        Unsafe.As<byte, uint>(ref stateRef) = Unsafe.As<byte, uint>(ref inputRef);
        Unsafe.As<byte, Vector128<byte>>(ref Unsafe.Add(ref stateRef, sizeof(uint))) =
            Unsafe.As<byte, Vector128<byte>>(ref Unsafe.Add(ref inputRef, sizeof(uint)));
    }
    else if (input.Length == Vector256<byte>.Count)
    {
        // Hashing a Hash256 or UInt256: 32 bytes, moved as a single Vector256.
        Unsafe.As<byte, Vector256<byte>>(ref MemoryMarshal.GetReference(stateBytes)) =
            Unsafe.As<byte, Vector256<byte>>(ref MemoryMarshal.GetReference(input));
    }
    else
    {
        // Any other input shorter than one block.
        input.CopyTo(stateBytes);
    }

    // Apply terminator markers within the current block. At this point input.Length < roundSize,
    // so both indices are inside the block.
    stateBytes[input.Length] ^= 0x01; // Append bit '1' after the remaining input
    stateBytes[roundSize - 1] ^= 0x80; // Set the last bit of the round to '1'

    // Process the final block
    KeccakF(state);

    if (output.Length == Vector256<byte>.Count)
    {
        // Fast Vector sized copy for Hash256
        Unsafe.As<byte, Vector256<byte>>(ref MemoryMarshal.GetReference(output)) =
            Unsafe.As<byte, Vector256<byte>>(ref MemoryMarshal.GetReference(stateBytes));
    }
    else if (output.Length == Vector512<byte>.Count)
    {
        // Fast Vector sized copy for Hash512
        Unsafe.As<byte, Vector512<byte>>(ref MemoryMarshal.GetReference(output)) =
            Unsafe.As<byte, Vector512<byte>>(ref MemoryMarshal.GetReference(stateBytes));
    }
    else
    {
        // Generic path: the digest is the leading output.Length bytes of the state.
        stateBytes[..output.Length].CopyTo(output);
    }
}

// last block and padding
if (input.Length >= TEMP_BUFF_SIZE || input.Length > roundSize || roundSize + 1 >= TEMP_BUFF_SIZE || roundSize == 0 || roundSize - 1 >= TEMP_BUFF_SIZE)
/// <summary>
/// XORs <paramref name="input"/> into the leading <c>input.Length</c> bytes of
/// <paramref name="state"/>, in place. Bytes of <paramref name="state"/> past that point are untouched.
/// </summary>
/// <param name="state">Destination bytes; must be at least as long as <paramref name="input"/>.</param>
/// <param name="input">Bytes to XOR into the state; may be empty.</param>
/// <remarks>
/// The widest hardware-accelerated vector width is used first, then narrower widths, then a
/// byte-by-byte tail. The accesses below are not bounds checked, so the length precondition is
/// the caller's responsibility (it is asserted in debug builds only).
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void XorVectors(Span<byte> state, ReadOnlySpan<byte> input)
{
    System.Diagnostics.Debug.Assert(input.Length <= state.Length, "input must not be longer than state");

    ref byte stateRef = ref MemoryMarshal.GetReference(state);
    ref byte inputRef = ref MemoryMarshal.GetReference(input);
    int length = input.Length;
    int offset = 0;

    // IsHardwareAccelerated (rather than Vector512<byte>.IsSupported, which only reports that
    // byte is a valid element type) keeps each width to machines that actually accelerate it.
    if (Vector512.IsHardwareAccelerated)
    {
        for (; offset <= length - Vector512<byte>.Count; offset += Vector512<byte>.Count)
        {
            Vector512<byte> mixed512 =
                Vector512.LoadUnsafe(ref stateRef, (nuint)offset) ^ Vector512.LoadUnsafe(ref inputRef, (nuint)offset);
            mixed512.StoreUnsafe(ref stateRef, (nuint)offset);
        }
    }

    if (Vector256.IsHardwareAccelerated)
    {
        for (; offset <= length - Vector256<byte>.Count; offset += Vector256<byte>.Count)
        {
            Vector256<byte> mixed256 =
                Vector256.LoadUnsafe(ref stateRef, (nuint)offset) ^ Vector256.LoadUnsafe(ref inputRef, (nuint)offset);
            mixed256.StoreUnsafe(ref stateRef, (nuint)offset);
        }
    }

    if (Vector128.IsHardwareAccelerated)
    {
        for (; offset <= length - Vector128<byte>.Count; offset += Vector128<byte>.Count)
        {
            Vector128<byte> mixed128 =
                Vector128.LoadUnsafe(ref stateRef, (nuint)offset) ^ Vector128.LoadUnsafe(ref inputRef, (nuint)offset);
            mixed128.StoreUnsafe(ref stateRef, (nuint)offset);
        }
    }

    // Handle remaining elements
    for (; offset < length; offset++)
    {
        Unsafe.Add(ref stateRef, offset) ^= Unsafe.Add(ref inputRef, offset);
    }
}

public void Update(ReadOnlySpan<byte> input)
Expand All @@ -412,85 +491,58 @@ public void Update(ReadOnlySpan<byte> input)
return;
}

// If our provided state is empty, initialize a new one
ulong[] state = _state;
if (state.Length == 0)
{
// If our provided state is empty, initialize a new one
_state = state = Pool.RentState();
}

// If our remainder is non zero.
int offset = 0;
Span<byte> stateBytes = MemoryMarshal.AsBytes(state.AsSpan());

// Handle any existing remainder
if (_remainderLength != 0)
{
// Copy data to our remainder
ReadOnlySpan<byte> remainderAdditive = input[..Math.Min(input.Length, _roundSize - _remainderLength)];
remainderAdditive.CopyTo(_remainderBuffer.AsSpan(_remainderLength));

// Increment the length
_remainderLength += remainderAdditive.Length;
int bytesToFill = _roundSize - _remainderLength;
int bytesToCopy = Math.Min(input.Length, bytesToFill);

// Increment the input
input = input[remainderAdditive.Length..];
input.Slice(0, bytesToCopy).CopyTo(_remainderBuffer.AsSpan(_remainderLength));
_remainderLength += bytesToCopy;
offset += bytesToCopy;

// If our remainder length equals a full round
if (_remainderLength == _roundSize)
{
// Cast our input to ulongs.
Span<ulong> remainderBufferU64 = MemoryMarshal.Cast<byte, ulong>(_remainderBuffer.AsSpan(0, _roundSize));

// Eliminate bounds check for state for the loop
_ = state[remainderBufferU64.Length];
// Loop for each ulong in this remainder, and xor the state with the input.
for (int i = 0; i < remainderBufferU64.Length; i++)
{
state[i] ^= remainderBufferU64[i];
}

// Perform our KeccakF on our state.
// XOR the remainder buffer into the state using XorVectors
XorVectors(stateBytes, _remainderBuffer);
KeccakF(state);

// Clear remainder fields
// Reset remainder
_remainderLength = 0;
Pool.ReturnRemainder(ref _remainderBuffer);
}
}

// Loop for every round in our size.
while (input.Length >= _roundSize)
// Process full rounds
while (input.Length - offset >= _roundSize)
{
// Cast our input to ulongs.
ReadOnlySpan<ulong> input64 = MemoryMarshal.Cast<byte, ulong>(input[.._roundSize]);

// Eliminate bounds check for state for the loop
_ = state[input64.Length];
// Loop for each ulong in this round, and xor the state with the input.
for (int i = 0; i < input64.Length; i++)
{
state[i] ^= input64[i];
}

// Perform our KeccakF on our state.
XorVectors(stateBytes, input.Slice(offset, _roundSize));
KeccakF(state);

// Remove the input data processed this round.
input = input[_roundSize..];
offset += _roundSize;
}

// last block and padding
if (input.Length >= TEMP_BUFF_SIZE || input.Length > _roundSize || _roundSize + 1 >= TEMP_BUFF_SIZE || _roundSize == 0 || _roundSize - 1 >= TEMP_BUFF_SIZE)
{
ThrowBadKeccak();
}

// If we have any remainder here, it means any remainder was processed before, we can copy our data over and set our length
if (input.Length > 0)
// Handle remaining input (less than a full block)
int remainingInputLength = input.Length - offset;
if (remainingInputLength > 0)
{
if (_remainderBuffer.Length == 0)
{
_remainderBuffer = Pool.RentRemainder();
}
input.CopyTo(_remainderBuffer);
_remainderLength = input.Length;

input.Slice(offset).CopyTo(_remainderBuffer);
_remainderLength = remainingInputLength;
}
}

Expand All @@ -512,39 +564,39 @@ public void UpdateFinalTo(Span<byte> output)
ThrowHashingComplete();
}

ulong[] state = _state;
Span<byte> stateBytes = MemoryMarshal.AsBytes(state.AsSpan());

if (_remainderLength > 0)
{
Span<byte> remainder = _remainderBuffer.AsSpan(0, _roundSize);
// Set a 1 byte after the remainder.
remainder[_remainderLength++] = 1;
// XOR the remainder buffer into the state
XorVectors(stateBytes, _remainderBuffer.AsSpan(0, _remainderLength));
}

// Set the highest bit on the last byte.
remainder[_roundSize - 1] |= 0x80;
// Apply terminator markers within the current block
stateBytes[_remainderLength] ^= 0x01; // Append bit '1' after the input
stateBytes[_roundSize - 1] ^= 0x80; // Set the last bit of the block to '1'

// Cast the remainder buffer to ulongs.
Span<ulong> temp64 = MemoryMarshal.Cast<byte, ulong>(remainder);
// Loop for each ulong in this round, and xor the state with the input.
for (int i = 0; i < temp64.Length; i++)
{
_state[i] ^= temp64[i];
}
KeccakF(state);

Pool.ReturnRemainder(ref _remainderBuffer);
// Copy the hash output
if (output.Length == Vector256<byte>.Count)
{
// Fast Vector sized copy for Hash256
Unsafe.As<byte, Vector256<byte>>(ref MemoryMarshal.GetReference(output)) =
Unsafe.As<byte, Vector256<byte>>(ref MemoryMarshal.GetReference(stateBytes));
}
else if (output.Length == Vector512<byte>.Count)
{
// Fast Vector sized copy for Hash512
Unsafe.As<byte, Vector512<byte>>(ref MemoryMarshal.GetReference(output)) =
Unsafe.As<byte, Vector512<byte>>(ref MemoryMarshal.GetReference(stateBytes));
}
else
{
Span<byte> temp = MemoryMarshal.AsBytes<ulong>(_state);
// Xor 1 byte as first byte.
temp[0] ^= 1;
// Xor the highest bit on the last byte.
temp[_roundSize - 1] ^= 0x80;
stateBytes[..output.Length].CopyTo(output);
}

KeccakF(_state);

// Obtain the state data in the desired (hash) size we want.
MemoryMarshal.AsBytes<ulong>(_state)[..HashSize].CopyTo(output);

Pool.ReturnState(ref _state);
}

Expand Down
Loading