Skip to content

Commit

Permalink
Implement Vector128 version of System.Buffers.Text.Base64 DecodeFromU…
Browse files Browse the repository at this point in the history
…tf8 and EncodeToUtf8 (#70654)

* Implement Vector128 version of System.Buffers.Text.Base64.DecodeFromUtf8

Rework the SS3 into a Vector128 version, and add Arm64 support.

* SSE3 improvements

* Remove superfluous bitwise And

* Add comment to SimdShuffle

* Inline SimdShuffle

* Implement Vector128 version of System.Buffers.Text.Base64.EncodeToUtf8

* Ensure masking on SSE3

Change-Id: I319f94cfc51d0542ae4eb11a8d48b3eb8180553f
CustomizedGitHooks: yes

* Restore asserts and move zero inside the loop

* Neater C# code

Change-Id: I2cbe14f4228f8035e7d213b5b58815c4eee35563
CustomizedGitHooks: yes

* Make SimdShuffle consistent across X64 and Arm64

* Better looking multiply
  • Loading branch information
a74nh authored Jun 20, 2022
1 parent 18ec279 commit d6d28e4
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 78 deletions.
109 changes: 64 additions & 45 deletions src/libraries/System.Memory/src/System/Buffers/Text/Base64Decoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.Arm;
using System.Runtime.Intrinsics.X86;

namespace System.Buffers.Text
{
// AVX2 version based on https://github.com/aklomp/base64/tree/e516d769a2a432c08404f1981e73b431566057be/lib/arch/avx2
// SSSE3 version based on https://github.com/aklomp/base64/tree/e516d769a2a432c08404f1981e73b431566057be/lib/arch/ssse3
// Vector128 version based on https://github.com/aklomp/base64/tree/e516d769a2a432c08404f1981e73b431566057be/lib/arch/ssse3

public static partial class Base64
{
Expand Down Expand Up @@ -74,9 +75,9 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan<byte> utf8, Spa
}

end = srcMax - 24;
if (Ssse3.IsSupported && (end >= src))
if ((Ssse3.IsSupported || AdvSimd.Arm64.IsSupported) && BitConverter.IsLittleEndian && (end >= src))
{
Ssse3Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes);
Vector128Decode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes);

if (src == srcEnd)
goto DoneExit;
Expand Down Expand Up @@ -476,10 +477,28 @@ private static unsafe void Avx2Decode(ref byte* srcBytes, ref byte* destBytes, b
destBytes = dest;
}

// This can be replaced once https://github.com/dotnet/runtime/issues/63331 is implemented.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void Ssse3Decode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart)
private static Vector128<byte> SimdShuffle(Vector128<byte> left, Vector128<byte> right, Vector128<byte> mask8F)
{
// If we have SSSE3 support, pick off 16 bytes at a time for as long as we can,
Debug.Assert((Ssse3.IsSupported || AdvSimd.Arm64.IsSupported) && BitConverter.IsLittleEndian);

if (Ssse3.IsSupported)
{
return Ssse3.Shuffle(left, right);
}
else
{
return AdvSimd.Arm64.VectorTableLookup(left, Vector128.BitwiseAnd(right, mask8F));
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void Vector128Decode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart)
{
Debug.Assert((Ssse3.IsSupported || AdvSimd.Arm64.IsSupported) && BitConverter.IsLittleEndian);

// If we have Vector128 support, pick off 16 bytes at a time for as long as we can,
// but make sure that we quit before seeing any == markers at the end of the
// string. Also, because we write four zeroes at the end of the output, ensure
// that there are at least 6 valid bytes of input data remaining to close the
Expand Down Expand Up @@ -552,34 +571,15 @@ private static unsafe void Ssse3Decode(ref byte* srcBytes, ref byte* destBytes,
// 1111 0x10 andlut 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10 0x10

// The JIT won't hoist these "constants", so help it
Vector128<sbyte> lutHi = Vector128.Create(
0x10, 0x10, 0x01, 0x02,
0x04, 0x08, 0x04, 0x08,
0x10, 0x10, 0x10, 0x10,
0x10, 0x10, 0x10, 0x10);

Vector128<sbyte> lutLo = Vector128.Create(
0x15, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x13, 0x1A,
0x1B, 0x1B, 0x1B, 0x1A);

Vector128<sbyte> lutShift = Vector128.Create(
0, 16, 19, 4,
-65, -65, -71, -71,
0, 0, 0, 0,
0, 0, 0, 0);

Vector128<sbyte> packBytesMask = Vector128.Create(
2, 1, 0, 6,
5, 4, 10, 9,
8, 14, 13, 12,
-1, -1, -1, -1);

Vector128<sbyte> mask2F = Vector128.Create((sbyte)'/');
Vector128<sbyte> mergeConstant0 = Vector128.Create(0x01400140).AsSByte();
Vector128<byte> lutHi = Vector128.Create(0x02011010, 0x08040804, 0x10101010, 0x10101010).AsByte();
Vector128<byte> lutLo = Vector128.Create(0x11111115, 0x11111111, 0x1A131111, 0x1A1B1B1B).AsByte();
Vector128<sbyte> lutShift = Vector128.Create(0x04131000, 0xb9b9bfbf, 0x00000000, 0x00000000).AsSByte();
Vector128<sbyte> packBytesMask = Vector128.Create(0x06000102, 0x090A0405, 0x0C0D0E08, 0xffffffff).AsSByte();
Vector128<byte> mergeConstant0 = Vector128.Create(0x01400140).AsByte();
Vector128<short> mergeConstant1 = Vector128.Create(0x00011000).AsInt16();
Vector128<sbyte> zero = Vector128<sbyte>.Zero;
Vector128<byte> one = Vector128.Create((byte)1);
Vector128<byte> mask2F = Vector128.Create((byte)'/');
Vector128<byte> mask8F = Vector128.Create((byte)0x8F);

byte* src = srcBytes;
byte* dest = destBytes;
Expand All @@ -588,52 +588,71 @@ private static unsafe void Ssse3Decode(ref byte* srcBytes, ref byte* destBytes,
do
{
AssertRead<Vector128<sbyte>>(src, srcStart, sourceLength);
Vector128<sbyte> str = Sse2.LoadVector128(src).AsSByte();
Vector128<byte> str = Vector128.LoadUnsafe(ref *src);

// lookup
Vector128<sbyte> hiNibbles = Sse2.And(Sse2.ShiftRightLogical(str.AsInt32(), 4).AsSByte(), mask2F);
Vector128<sbyte> loNibbles = Sse2.And(str, mask2F);
Vector128<sbyte> hi = Ssse3.Shuffle(lutHi, hiNibbles);
Vector128<sbyte> lo = Ssse3.Shuffle(lutLo, loNibbles);
Vector128<byte> hiNibbles = Vector128.ShiftRightLogical(str.AsInt32(), 4).AsByte() & mask2F;
Vector128<byte> hi = SimdShuffle(lutHi, hiNibbles, mask8F);
Vector128<byte> lo = SimdShuffle(lutLo, str, mask8F);

// Check for invalid input: if any "and" values from lo and hi are not zero,
// fall back on bytewise code to do error checking and reporting:
if (Sse2.MoveMask(Sse2.CompareGreaterThan(Sse2.And(lo, hi), zero)) != 0)
if ((lo & hi) != Vector128<byte>.Zero)
break;

Vector128<sbyte> eq2F = Sse2.CompareEqual(str, mask2F);
Vector128<sbyte> shift = Ssse3.Shuffle(lutShift, Sse2.Add(eq2F, hiNibbles));
Vector128<byte> eq2F = Vector128.Equals(str, mask2F);
Vector128<byte> shift = SimdShuffle(lutShift.AsByte(), (eq2F + hiNibbles), mask8F);

// Now simply add the delta values to the input:
str = Sse2.Add(str, shift);
str += shift;

// in, bits, upper case are most significant bits, lower case are least significant bits
// 00llllll 00kkkkLL 00jjKKKK 00JJJJJJ
// 00iiiiii 00hhhhII 00ggHHHH 00GGGGGG
// 00ffffff 00eeeeFF 00ddEEEE 00DDDDDD
// 00cccccc 00bbbbCC 00aaBBBB 00AAAAAA

Vector128<short> merge_ab_and_bc = Ssse3.MultiplyAddAdjacent(str.AsByte(), mergeConstant0);
Vector128<short> merge_ab_and_bc;
if (Ssse3.IsSupported)
{
merge_ab_and_bc = Ssse3.MultiplyAddAdjacent(str.AsByte(), mergeConstant0.AsSByte());
}
else
{
Vector128<ushort> evens = AdvSimd.ShiftLeftLogicalWideningLower(AdvSimd.Arm64.UnzipEven(str, one).GetLower(), 6);
Vector128<ushort> odds = AdvSimd.Arm64.TransposeOdd(str, Vector128<byte>.Zero).AsUInt16();
merge_ab_and_bc = Vector128.Add(evens, odds).AsInt16();
}
// 0000kkkk LLllllll 0000JJJJ JJjjKKKK
// 0000hhhh IIiiiiii 0000GGGG GGggHHHH
// 0000eeee FFffffff 0000DDDD DDddEEEE
// 0000bbbb CCcccccc 0000AAAA AAaaBBBB

Vector128<int> output = Sse2.MultiplyAddAdjacent(merge_ab_and_bc, mergeConstant1);
Vector128<int> output;
if (Ssse3.IsSupported)
{
output = Sse2.MultiplyAddAdjacent(merge_ab_and_bc, mergeConstant1);
}
else
{
Vector128<int> ievens = AdvSimd.ShiftLeftLogicalWideningLower(AdvSimd.Arm64.UnzipEven(merge_ab_and_bc, one.AsInt16()).GetLower(), 12);
Vector128<int> iodds = AdvSimd.Arm64.TransposeOdd(merge_ab_and_bc, Vector128<short>.Zero).AsInt32();
output = Vector128.Add(ievens, iodds).AsInt32();
}
// 00000000 JJJJJJjj KKKKkkkk LLllllll
// 00000000 GGGGGGgg HHHHhhhh IIiiiiii
// 00000000 DDDDDDdd EEEEeeee FFffffff
// 00000000 AAAAAAaa BBBBbbbb CCcccccc

// Pack bytes together:
str = Ssse3.Shuffle(output.AsSByte(), packBytesMask);
str = SimdShuffle(output.AsByte(), packBytesMask.AsByte(), mask8F);
// 00000000 00000000 00000000 00000000
// LLllllll KKKKkkkk JJJJJJjj IIiiiiii
// HHHHhhhh GGGGGGgg FFffffff EEEEeeee
// DDDDDDdd CCcccccc BBBBbbbb AAAAAAaa

AssertWrite<Vector128<sbyte>>(dest, destStart, destLength);
Sse2.Store(dest, str.AsByte());
str.Store(dest);

src += 16;
dest += 12;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.Arm;
using System.Runtime.Intrinsics.X86;

namespace System.Buffers.Text
{
// AVX2 version based on https://github.com/aklomp/base64/tree/e516d769a2a432c08404f1981e73b431566057be/lib/arch/avx2
// SSSE3 version based on https://github.com/aklomp/base64/tree/e516d769a2a432c08404f1981e73b431566057be/lib/arch/ssse3
// Vector128 version based on https://github.com/aklomp/base64/tree/e516d769a2a432c08404f1981e73b431566057be/lib/arch/ssse3

/// <summary>
/// Convert between binary data and UTF-8 encoded text that is represented in base 64.
Expand Down Expand Up @@ -75,9 +76,9 @@ public static unsafe OperationStatus EncodeToUtf8(ReadOnlySpan<byte> bytes, Span
}

end = srcMax - 16;
if (Ssse3.IsSupported && (end >= src))
if ((Ssse3.IsSupported || AdvSimd.Arm64.IsSupported) && BitConverter.IsLittleEndian && (end >= src))
{
Ssse3Encode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes);
Vector128Encode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes);

if (src == srcEnd)
goto DoneExit;
Expand Down Expand Up @@ -395,7 +396,7 @@ private static unsafe void Avx2Encode(ref byte* srcBytes, ref byte* destBytes, b
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void Ssse3Encode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart)
private static unsafe void Vector128Encode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart)
{
// If we have SSSE3 support, pick off 12 bytes at a time for as long as we can.
// But because we read 16 bytes at a time, ensure we have enough room to do a
Expand All @@ -405,24 +406,15 @@ private static unsafe void Ssse3Encode(ref byte* srcBytes, ref byte* destBytes,
// 0 0 0 0 l k j i h g f e d c b a

// The JIT won't hoist these "constants", so help it
Vector128<sbyte> shuffleVec = Vector128.Create(
1, 0, 2, 1,
4, 3, 5, 4,
7, 6, 8, 7,
10, 9, 11, 10);

Vector128<sbyte> lut = Vector128.Create(
65, 71, -4, -4,
-4, -4, -4, -4,
-4, -4, -4, -4,
-19, -16, 0, 0);

Vector128<sbyte> maskAC = Vector128.Create(0x0fc0fc00).AsSByte();
Vector128<sbyte> maskBB = Vector128.Create(0x003f03f0).AsSByte();
Vector128<byte> shuffleVec = Vector128.Create(0x01020001, 0x04050304, 0x07080607, 0x0A0B090A).AsByte();
Vector128<byte> lut = Vector128.Create(0xFCFC4741, 0xFCFCFCFC, 0xFCFCFCFC, 0x0000F0ED).AsByte();
Vector128<byte> maskAC = Vector128.Create(0x0fc0fc00).AsByte();
Vector128<byte> maskBB = Vector128.Create(0x003f03f0).AsByte();
Vector128<ushort> shiftAC = Vector128.Create(0x04000040).AsUInt16();
Vector128<short> shiftBB = Vector128.Create(0x01000010).AsInt16();
Vector128<byte> const51 = Vector128.Create((byte)51);
Vector128<sbyte> const25 = Vector128.Create((sbyte)25);
Vector128<short> shiftBB = Vector128.Create(0x01000010).AsInt16();
Vector128<byte> const51 = Vector128.Create((byte)51);
Vector128<sbyte> const25 = Vector128.Create((sbyte)25);
Vector128<byte> mask8F = Vector128.Create((byte)0x8F);

byte* src = srcBytes;
byte* dest = destBytes;
Expand All @@ -431,42 +423,52 @@ private static unsafe void Ssse3Encode(ref byte* srcBytes, ref byte* destBytes,
do
{
AssertRead<Vector128<sbyte>>(src, srcStart, sourceLength);
Vector128<sbyte> str = Sse2.LoadVector128(src).AsSByte();
Vector128<byte> str = Vector128.LoadUnsafe(ref *src);

// Reshuffle
str = Ssse3.Shuffle(str, shuffleVec);
str = SimdShuffle(str, shuffleVec, mask8F);
// str, bytes MSB to LSB:
// k l j k
// h i g h
// e f d e
// b c a b

Vector128<sbyte> t0 = Sse2.And(str, maskAC);
Vector128<byte> t0 = str & maskAC;
// bits, upper case are most significant bits, lower case are least significant bits
// 0000kkkk LL000000 JJJJJJ00 00000000
// 0000hhhh II000000 GGGGGG00 00000000
// 0000eeee FF000000 DDDDDD00 00000000
// 0000bbbb CC000000 AAAAAA00 00000000

Vector128<sbyte> t2 = Sse2.And(str, maskBB);
Vector128<byte> t2 = str & maskBB;
// 00000000 00llllll 000000jj KKKK0000
// 00000000 00iiiiii 000000gg HHHH0000
// 00000000 00ffffff 000000dd EEEE0000
// 00000000 00cccccc 000000aa BBBB0000

Vector128<ushort> t1 = Sse2.MultiplyHigh(t0.AsUInt16(), shiftAC);
Vector128<ushort> t1;
if (Ssse3.IsSupported)
{
t1 = Sse2.MultiplyHigh(t0.AsUInt16(), shiftAC);
}
else
{
Vector128<ushort> odd = Vector128.ShiftRightLogical(AdvSimd.Arm64.UnzipOdd(t0.AsUInt16(), t0.AsUInt16()), 6);
Vector128<ushort> even = Vector128.ShiftRightLogical(AdvSimd.Arm64.UnzipEven(t0.AsUInt16(), t0.AsUInt16()), 10);
t1 = AdvSimd.Arm64.ZipLow(even, odd);
}
// 00000000 00kkkkLL 00000000 00JJJJJJ
// 00000000 00hhhhII 00000000 00GGGGGG
// 00000000 00eeeeFF 00000000 00DDDDDD
// 00000000 00bbbbCC 00000000 00AAAAAA

Vector128<short> t3 = Sse2.MultiplyLow(t2.AsInt16(), shiftBB);
Vector128<short> t3 = t2.AsInt16() * shiftBB;
// 00llllll 00000000 00jjKKKK 00000000
// 00iiiiii 00000000 00ggHHHH 00000000
// 00ffffff 00000000 00ddEEEE 00000000
// 00cccccc 00000000 00aaBBBB 00000000

str = Sse2.Or(t1.AsSByte(), t3.AsSByte());
str = t1.AsByte() | t3.AsByte();
// 00llllll 00kkkkLL 00jjKKKK 00JJJJJJ
// 00iiiiii 00hhhhII 00ggHHHH 00GGGGGG
// 00ffffff 00eeeeFF 00ddEEEE 00DDDDDD
Expand All @@ -484,19 +486,27 @@ private static unsafe void Ssse3Encode(ref byte* srcBytes, ref byte* destBytes,

// Create LUT indices from input:
// the index for range #0 is right, others are 1 less than expected:
Vector128<byte> indices = Sse2.SubtractSaturate(str.AsByte(), const51);
Vector128<byte> indices;
if (Ssse3.IsSupported)
{
indices = Sse2.SubtractSaturate(str.AsByte(), const51);
}
else
{
indices = AdvSimd.SubtractSaturate(str.AsByte(), const51);
}

// mask is 0xFF (-1) for range #[1..4] and 0x00 for range #0:
Vector128<sbyte> mask = Sse2.CompareGreaterThan(str, const25);
Vector128<sbyte> mask = Vector128.GreaterThan(str.AsSByte(), const25);

// substract -1, so add 1 to indices for range #[1..4], All indices are now correct:
Vector128<sbyte> tmp = Sse2.Subtract(indices.AsSByte(), mask);
Vector128<sbyte> tmp = indices.AsSByte() - mask;

// Add offsets to input values:
str = Sse2.Add(str, Ssse3.Shuffle(lut, tmp));
str += SimdShuffle(lut, tmp.AsByte(), mask8F);

AssertWrite<Vector128<sbyte>>(dest, destStart, destLength);
Sse2.Store(dest, str.AsByte());
str.Store(dest);

src += 12;
dest += 16;
Expand Down

0 comments on commit d6d28e4

Please sign in to comment.