From 208b974f93c0ee35a171eba374ca36aa9a78e930 Mon Sep 17 00:00:00 2001 From: Alex Covington <68252706+alexcovington@users.noreply.github.com> Date: Thu, 19 Dec 2024 13:15:10 -0800 Subject: [PATCH] Add vector support to System.Numerics.Tensors.TensorPrimitives.LeadingZeroCount for Byte and Int16 (#110333) * Working for byte * Working for ushort * Cleanup * Formatting * Update src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.LeadingZeroCount.cs Co-authored-by: Tanner Gooding * Add comment for usage of PermuteVar64x8x2 --------- Co-authored-by: Alex Covington (Advanced Micro Devices Co-authored-by: Tanner Gooding --- .../TensorPrimitives.LeadingZeroCount.cs | 119 ++++++++++++++++-- 1 file changed, 112 insertions(+), 7 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.LeadingZeroCount.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.LeadingZeroCount.cs index 11018124a5992..0bc08e97166c2 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.LeadingZeroCount.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.LeadingZeroCount.cs @@ -29,7 +29,8 @@ public static void LeadingZeroCount(ReadOnlySpan x, Span destination) internal readonly unsafe struct LeadingZeroCountOperator : IUnaryOperator where T : IBinaryInteger { public static bool Vectorizable => - (Avx512CD.VL.IsSupported && (sizeof(T) == 4 || sizeof(T) == 8)) || + (Avx512CD.VL.IsSupported && (sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8)) || + (Avx512Vbmi.VL.IsSupported && sizeof(T) == 1) || (AdvSimd.IsSupported && (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4)); public static T Invoke(T x) => T.LeadingZeroCount(x); @@ -37,10 +38,43 @@ public static void LeadingZeroCount(ReadOnlySpan x, Span destination) [MethodImpl(MethodImplOptions.AggressiveInlining)] public static Vector128 Invoke(Vector128 x) { + if (Avx512Vbmi.VL.IsSupported && sizeof(T) == 1) + { + Vector128 lookupVectorLow = Vector128.Create((byte)8, 7, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4); + Vector128 lookupVectorHigh = Vector128.Create((byte)3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0); + Vector128 nibbleMask = Vector128.Create(0xF); + Vector128 permuteMask = Vector128.Create(0x80); + Vector128 lowNibble = x.AsByte() & nibbleMask; + Vector128 highNibble = Sse2.ShiftRightLogical(x.AsInt32(), 4).AsByte() & nibbleMask; + Vector128 nibbleSelectMask = Sse2.CompareEqual(highNibble, Vector128.Zero); + Vector128 indexVector = Sse41.BlendVariable(highNibble, lowNibble, nibbleSelectMask) + + (~nibbleSelectMask & nibbleMask); + indexVector |= ~nibbleSelectMask & permuteMask; + return Avx512Vbmi.VL.PermuteVar16x8x2(lookupVectorLow, indexVector, lookupVectorHigh).As(); + } + if (Avx512CD.VL.IsSupported) { - if (sizeof(T) == 4) return Avx512CD.VL.LeadingZeroCount(x.AsUInt32()).As(); - if (sizeof(T) == 8) return Avx512CD.VL.LeadingZeroCount(x.AsUInt64()).As(); + if (sizeof(T) == 2) + { + Vector128 lowHalf = Vector128.Create((uint)0x0000FFFF); + Vector128 x_bot16 = Sse2.Or(Sse2.ShiftLeftLogical(x.AsUInt32(), 16), lowHalf); + Vector128 x_top16 = Sse2.Or(x.AsUInt32(), lowHalf); + Vector128 lz_bot16 = Avx512CD.VL.LeadingZeroCount(x_bot16); + Vector128 lz_top16 = Avx512CD.VL.LeadingZeroCount(x_top16); + Vector128 lz_top16_shift = Sse2.ShiftLeftLogical(lz_top16, 16); + return Sse2.Or(lz_bot16, lz_top16_shift).AsUInt16().As(); + } + + if (sizeof(T) == 4) + { + return Avx512CD.VL.LeadingZeroCount(x.AsUInt32()).As(); + } + + if (sizeof(T) == 8) + { + return Avx512CD.VL.LeadingZeroCount(x.AsUInt64()).As(); + } } Debug.Assert(AdvSimd.IsSupported); @@ -56,10 +90,42 @@ public static Vector128 Invoke(Vector128 x) [MethodImpl(MethodImplOptions.AggressiveInlining)] public static Vector256 Invoke(Vector256 x) { + if (Avx512Vbmi.VL.IsSupported && sizeof(T) == 1) + { + Vector256 lookupVector = + Vector256.Create((byte)8, 7, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, + 3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0); + Vector256 nibbleMask = Vector256.Create(0xF); + Vector256 lowNibble = x.AsByte() & nibbleMask; + Vector256 highNibble = Avx2.ShiftRightLogical(x.AsInt32(), 4).AsByte() & nibbleMask; + Vector256 nibbleSelectMask = Avx2.CompareEqual(highNibble, Vector256.Zero); + Vector256 indexVector = Avx2.BlendVariable(highNibble, lowNibble, nibbleSelectMask) + + (~nibbleSelectMask & nibbleMask); + return Avx512Vbmi.VL.PermuteVar32x8(lookupVector, indexVector).As(); + } + if (Avx512CD.VL.IsSupported) { - if (sizeof(T) == 4) return Avx512CD.VL.LeadingZeroCount(x.AsUInt32()).As(); - if (sizeof(T) == 8) return Avx512CD.VL.LeadingZeroCount(x.AsUInt64()).As(); + if (sizeof(T) == 2) + { + Vector256 lowHalf = Vector256.Create((uint)0x0000FFFF); + Vector256 x_bot16 = Avx2.Or(Avx2.ShiftLeftLogical(x.AsUInt32(), 16), lowHalf); + Vector256 x_top16 = Avx2.Or(x.AsUInt32(), lowHalf); + Vector256 lz_bot16 = Avx512CD.VL.LeadingZeroCount(x_bot16); + Vector256 lz_top16 = Avx512CD.VL.LeadingZeroCount(x_top16); + Vector256 lz_top16_shift = Avx2.ShiftLeftLogical(lz_top16, 16); + return Avx2.Or(lz_bot16, lz_top16_shift).AsUInt16().As(); + } + + if (sizeof(T) == 4) + { + return Avx512CD.VL.LeadingZeroCount(x.AsUInt32()).As(); + } + + if (sizeof(T) == 8) + { + return Avx512CD.VL.LeadingZeroCount(x.AsUInt64()).As(); + } } return Vector256.Create(Invoke(x.GetLower()), Invoke(x.GetUpper())); @@ -68,10 +134,49 @@ public static Vector256 Invoke(Vector256 x) [MethodImpl(MethodImplOptions.AggressiveInlining)] public static Vector512 Invoke(Vector512 x) { + if (Avx512BW.IsSupported && Avx512Vbmi.IsSupported && sizeof(T) == 1) + { + // Use each element of x as an index into a lookup table. + // Lookup can be broken down into the following: + // Bit 7 is set -- Result is 0, else result is from lookup table + // Bit 6 is set -- Use lookupVectorB, else use lookupVectorA + // Bit 5:0 -- Index to use for lookup table + Vector512 lookupVectorA = + Vector512.Create((byte)8, 7, 6, 6, 5, 5, 5, 5, + 4, 4, 4, 4, 4, 4, 4, 4, + 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, + 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2); + Vector512 lookupVectorB = Vector512.Create((byte)1); + Vector512 bit7ZeroMask = Avx512BW.CompareLessThan(x.AsByte(), Vector512.Create((byte)128)); + return Avx512F.And(bit7ZeroMask, Avx512Vbmi.PermuteVar64x8x2(lookupVectorA, x.AsByte(), lookupVectorB)).As(); + } + if (Avx512CD.IsSupported) { - if (sizeof(T) == 4) return Avx512CD.LeadingZeroCount(x.AsUInt32()).As(); - if (sizeof(T) == 8) return Avx512CD.LeadingZeroCount(x.AsUInt64()).As(); + if (sizeof(T) == 2) + { + Vector512 lowHalf = Vector512.Create((uint)0x0000FFFF); + Vector512 x_bot16 = Avx512F.Or(Avx512F.ShiftLeftLogical(x.AsUInt32(), 16), lowHalf); + Vector512 x_top16 = Avx512F.Or(x.AsUInt32(), lowHalf); + Vector512 lz_bot16 = Avx512CD.LeadingZeroCount(x_bot16); + Vector512 lz_top16 = Avx512CD.LeadingZeroCount(x_top16); + Vector512 lz_top16_shift = Avx512F.ShiftLeftLogical(lz_top16, 16); + return Avx512F.Or(lz_bot16, lz_top16_shift).AsUInt16().As(); + } + + if (sizeof(T) == 4) + { + return Avx512CD.LeadingZeroCount(x.AsUInt32()).As(); + } + + if (sizeof(T) == 8) + { + return Avx512CD.LeadingZeroCount(x.AsUInt64()).As(); + } } return Vector512.Create(Invoke(x.GetLower()), Invoke(x.GetUpper()));