Add vector support to System.Numerics.Tensors.TensorPrimitives.LeadingZeroCount for Byte and Int16 #110333

Merged on Dec 19, 2024 (11 commits)
@@ -29,7 +29,8 @@ public static void LeadingZeroCount<T>(ReadOnlySpan<T> x, Span<T> destination)
internal readonly unsafe struct LeadingZeroCountOperator<T> : IUnaryOperator<T, T> where T : IBinaryInteger<T>
{
public static bool Vectorizable =>
- (Avx512CD.VL.IsSupported && (sizeof(T) == 4 || sizeof(T) == 8)) ||
+ (Avx512CD.VL.IsSupported && (sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8)) ||
+ (Avx512BW.IsSupported && Avx512Vbmi.IsSupported && Avx512Vbmi.VL.IsSupported && sizeof(T) == 1) ||
(AdvSimd.IsSupported && (sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4));

public static T Invoke(T x) => T.LeadingZeroCount(x);
@@ -41,6 +42,30 @@ public static Vector128<T> Invoke(Vector128<T> x)
{
if (sizeof(T) == 4) return Avx512CD.VL.LeadingZeroCount(x.AsUInt32()).As<uint, T>();
if (sizeof(T) == 8) return Avx512CD.VL.LeadingZeroCount(x.AsUInt64()).As<ulong, T>();
if (sizeof(T) == 2)
{
Vector128<uint> lowHalf = Vector128.Create((uint)0x0000FFFF);
Vector128<uint> x_bot16 = Sse2.Or(Sse2.ShiftLeftLogical(x.AsUInt32(), 16), lowHalf);
Vector128<uint> x_top16 = Sse2.Or(x.AsUInt32(), lowHalf);
Vector128<uint> lz_bot16 = Avx512CD.VL.LeadingZeroCount(x_bot16);
Vector128<uint> lz_top16 = Avx512CD.VL.LeadingZeroCount(x_top16);
Vector128<uint> lz_top16_shift = Sse2.ShiftLeftLogical(lz_top16, 16);
return Sse2.Or(lz_bot16, lz_top16_shift).AsUInt16().As<ushort, T>();
}
}
if (Avx512Vbmi.VL.IsSupported && sizeof(T) == 1)
{
Vector128<byte> lookupVectorLow = Vector128.Create((byte)8, 7, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4);
Vector128<byte> lookupVectorHigh = Vector128.Create((byte)3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0);
Vector128<byte> nibbleMask = Vector128.Create<byte>(0xF);
Vector128<byte> permuteMask = Vector128.Create<byte>(0x80);
Vector128<byte> lowNibble = x.AsByte() & nibbleMask;
Vector128<byte> highNibble = Sse2.ShiftRightLogical(x.AsInt32(), 4).AsByte() & nibbleMask;
Vector128<byte> nibbleSelectMask = Sse2.CompareEqual(highNibble, Vector128<byte>.Zero);
Vector128<byte> indexVector = Sse41.BlendVariable(highNibble, lowNibble, nibbleSelectMask) +
(~nibbleSelectMask & nibbleMask);
indexVector |= ~nibbleSelectMask & permuteMask;
return Avx512Vbmi.VL.PermuteVar16x8x2(lookupVectorLow, indexVector, lookupVectorHigh).As<byte, T>();
Review comment on lines +43 to +53, from @tannergooding (Member), Dec 4, 2024:

Is this cheaper than:

Vector512<int> x32 = Avx512F.ConvertToVector512Int32(x.AsByte());
Vector512<int> lz = Avx512CD.LeadingZeroCount(x32);
return Avx512F.ConvertToVector128Byte(lz) - Vector128.Create((byte)24);
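
For readers following the thread, the arithmetic behind this suggestion can be checked one element at a time. The sketch below is a scalar model only, not code from the PR: the class and method names are hypothetical, and `BitOperations.LeadingZeroCount` stands in for the 32-bit vector instruction. Zero-extending a byte to 32 bits adds 24 leading zero bits, which is why 24 is subtracted after the count.

```csharp
using System.Numerics;

// Hypothetical scalar model of the "widen, lzcnt, narrow" idea (illustration only).
static class WidenLzcntModel
{
    public static byte LeadingZeroCountViaWiden(byte value)
    {
        // Zero-extend the byte to 32 bits; the upper 24 bits are always zero.
        uint widened = value;

        // The 32-bit count is in the range 24..32 for any byte input.
        int lz32 = BitOperations.LeadingZeroCount(widened);

        // Remove the 24 bits contributed by the zero extension.
        return (byte)(lz32 - 24);
    }
}
```

Under these assumptions, an input of 0 gives 32 - 24 = 8 and an input of 0x80 gives 24 - 24 = 0, matching the per-byte definition.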

Reply from the PR author (Contributor):

Widening and then narrowing adds overhead.

In this case, the widening version gives bimodal performance results. To verify, the same microbenchmark can be modified to stress this path specifically by setting BufferLength=16.

Some runs look like this:

| Method           | Job        | Toolchain                                       | BufferLength | Mean      | Error     | StdDev    | Median    | Min       | Max       | Ratio | RatioSD | Allocated | Alloc Ratio |
|----------------- |----------- |------------------------------------------------ |------------- |----------:|----------:|----------:|----------:|----------:|----------:|------:|--------:|----------:|------------:|
| LeadingZeroCount | Job-RJAKMA | Current PR                                      | 16           |  2.676 ns | 0.0525 ns | 0.0491 ns |  2.680 ns |  2.610 ns |  2.754 ns |  1.00 |    0.03 |         - |          NA |
| LeadingZeroCount | Job-FSPMRZ | Widen                                           | 16           |  3.485 ns | 0.0365 ns | 0.0342 ns |  3.502 ns |  3.428 ns |  3.526 ns |  1.30 |    0.03 |         - |          NA |

Other runs look like this:

| Method           | Job        | Toolchain                                       | BufferLength | Mean      | Error     | StdDev    | Median    | Min       | Max       | Ratio | RatioSD | Allocated | Alloc Ratio |
|----------------- |----------- |------------------------------------------------ |------------- |----------:|----------:|----------:|----------:|----------:|----------:|------:|--------:|----------:|------------:|
| LeadingZeroCount | Job-MGUUAK | Current PR                                      | 16           |  2.683 ns | 0.0424 ns | 0.0396 ns |  2.695 ns |  2.616 ns |  2.733 ns |  1.00 |    0.02 |         - |          NA |
| LeadingZeroCount | Job-NBPOWJ | Widen                                           | 16           |  2.484 ns | 0.0334 ns | 0.0296 ns |  2.492 ns |  2.427 ns |  2.519 ns |  0.93 |    0.02 |         - |          NA |

I chose the version in this PR because its timings were more consistent across runs.
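
As an aid to reading the lookup-based path discussed in this thread, here is a scalar sketch of how its index appears to be formed, using the same table values as the diff. It is an illustration under stated assumptions, not the PR's implementation: the class and method names are hypothetical, and plain arrays replace the permute instruction. When the high nibble is zero the low nibble indexes the first table directly; otherwise the high nibble plus 15 lands in the second table.

```csharp
// Hypothetical scalar model of the nibble-based byte lookup (illustration only).
static class NibbleLzcntModel
{
    // Same values as lookupVectorLow / lookupVectorHigh in the diff.
    private static readonly byte[] LookupLow = { 8, 7, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4 };
    private static readonly byte[] LookupHigh = { 3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0 };

    public static byte LeadingZeroCountViaNibbles(byte x)
    {
        int lowNibble = x & 0xF;
        int highNibble = (x >> 4) & 0xF;

        // High nibble zero: index 0..15 (first table).
        // High nibble non-zero: index 16..30 (second table).
        int index = highNibble == 0 ? lowNibble : highNibble + 15;

        // Bit 4 of the index picks the table, the low 4 bits pick the entry.
        return index < 16 ? LookupLow[index] : LookupHigh[index - 16];
    }
}
```

For example, 0x10 has high nibble 1, so the index is 16 and the result is the first entry of the second table, 3.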

}

Debug.Assert(AdvSimd.IsSupported);
@@ -60,6 +85,29 @@ public static Vector256<T> Invoke(Vector256<T> x)
{
if (sizeof(T) == 4) return Avx512CD.VL.LeadingZeroCount(x.AsUInt32()).As<uint, T>();
if (sizeof(T) == 8) return Avx512CD.VL.LeadingZeroCount(x.AsUInt64()).As<ulong, T>();
if (sizeof(T) == 2)
{
Vector256<uint> lowHalf = Vector256.Create((uint)0x0000FFFF);
Vector256<uint> x_bot16 = Avx2.Or(Avx2.ShiftLeftLogical(x.AsUInt32(), 16), lowHalf);
Vector256<uint> x_top16 = Avx2.Or(x.AsUInt32(), lowHalf);
Vector256<uint> lz_bot16 = Avx512CD.VL.LeadingZeroCount(x_bot16);
Vector256<uint> lz_top16 = Avx512CD.VL.LeadingZeroCount(x_top16);
Vector256<uint> lz_top16_shift = Avx2.ShiftLeftLogical(lz_top16, 16);
return Avx2.Or(lz_bot16, lz_top16_shift).AsUInt16().As<ushort, T>();
Review comment on lines +111 to +117 (Member):

Same question as the previous comment: would widening to Vector512 be cheaper here?

Reply from the PR author (Contributor):

Widening and then narrowing performs about the same in this case. This can be verified with BufferLength=16:

| Type                                      | Method           | Job        | Toolchain                                       | BufferLength | Mean      | Error     | StdDev    | Median    | Min       | Max       | Ratio | RatioSD | Allocated | Alloc Ratio |
|------------------------------------------ |----------------- |----------- |------------------------------------------------ |------------- |----------:|----------:|----------:|----------:|----------:|----------:|------:|--------:|----------:|------------:|
| Perf_BinaryIntegerTensorPrimitives<Int16> | LeadingZeroCount | Job-WWKLJQ | Current PR                                      | 16           |  2.485 ns | 0.0442 ns | 0.0414 ns |  2.496 ns |  2.410 ns |  2.530 ns |  1.00 |    0.02 |         - |          NA |
| Perf_BinaryIntegerTensorPrimitives<Int16> | LeadingZeroCount | Job-VFJFZO | Widen                                           | 16           |  2.474 ns | 0.0542 ns | 0.0507 ns |  2.495 ns |  2.402 ns |  2.529 ns |  1.00 |    0.03 |         - |          NA |
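
The non-widening Int16 path in this hunk can be modeled on a single 32-bit lane that holds two 16-bit elements. The following is a scalar sketch with hypothetical names, assuming `BitOperations.LeadingZeroCount` as a stand-in for the 32-bit vector count. Filling the low 16 bits with ones caps each count at 16, so an all-zero element yields 16 rather than 32.

```csharp
using System.Numerics;

// Hypothetical scalar model of the packed 16-bit trick on one 32-bit lane (illustration only).
static class PackedLzcnt16Model
{
    public static uint LeadingZeroCountPacked16(uint lane)
    {
        const uint lowHalf = 0x0000FFFF;

        // Move the bottom element into the top 16 bits, fill the rest with ones.
        uint bottom = (lane << 16) | lowHalf;

        // Keep the top element in place, fill the low 16 bits with ones.
        uint top = lane | lowHalf;

        uint lzBottom = (uint)BitOperations.LeadingZeroCount(bottom); // 0..16
        uint lzTop = (uint)BitOperations.LeadingZeroCount(top);       // 0..16

        // Pack both counts back into their original 16-bit positions.
        return lzBottom | (lzTop << 16);
    }
}
```

For instance, a lane of 0x00010000 holds the elements 0x0001 (top) and 0x0000 (bottom), and the model returns 0x000F0010, that is, counts of 15 and 16.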

}
}
if (Avx512Vbmi.VL.IsSupported && sizeof(T) == 1)
{
Vector256<byte> lookupVector =
Vector256.Create((byte)8, 7, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4,
3, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0);
Vector256<byte> nibbleMask = Vector256.Create<byte>(0xF);
Vector256<byte> lowNibble = x.AsByte() & nibbleMask;
Vector256<byte> highNibble = Avx2.ShiftRightLogical(x.AsInt32(), 4).AsByte() & nibbleMask;
Vector256<byte> nibbleSelectMask = Avx2.CompareEqual(highNibble, Vector256<byte>.Zero);
Vector256<byte> indexVector = Avx2.BlendVariable(highNibble, lowNibble, nibbleSelectMask) +
(~nibbleSelectMask & nibbleMask);
return Avx512Vbmi.VL.PermuteVar32x8(lookupVector, indexVector).As<byte, T>();
}

return Vector256.Create(Invoke(x.GetLower()), Invoke(x.GetUpper()));
@@ -72,6 +120,31 @@ public static Vector512<T> Invoke(Vector512<T> x)
{
if (sizeof(T) == 4) return Avx512CD.LeadingZeroCount(x.AsUInt32()).As<uint, T>();
if (sizeof(T) == 8) return Avx512CD.LeadingZeroCount(x.AsUInt64()).As<ulong, T>();
if (sizeof(T) == 2)
{
Vector512<uint> lowHalf = Vector512.Create((uint)0x0000FFFF);
Vector512<uint> x_bot16 = Avx512F.Or(Avx512F.ShiftLeftLogical(x.AsUInt32(), 16), lowHalf);
Vector512<uint> x_top16 = Avx512F.Or(x.AsUInt32(), lowHalf);
Vector512<uint> lz_bot16 = Avx512CD.LeadingZeroCount(x_bot16);
Vector512<uint> lz_top16 = Avx512CD.LeadingZeroCount(x_top16);
Vector512<uint> lz_top16_shift = Avx512F.ShiftLeftLogical(lz_top16, 16);
return Avx512F.Or(lz_bot16, lz_top16_shift).AsUInt16().As<ushort, T>();
}
}
if (Avx512BW.IsSupported && Avx512Vbmi.IsSupported && sizeof(T) == 1)
{
Vector512<byte> lookupVectorA =
Vector512.Create((byte)8, 7, 6, 6, 5, 5, 5, 5,
4, 4, 4, 4, 4, 4, 4, 4,
3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3,
2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2);
Vector512<byte> lookupVectorB = Vector512.Create((byte)1);
Vector512<byte> bit7ZeroMask = Avx512BW.CompareLessThan(x.AsByte(), Vector512.Create((byte)128));
return Avx512F.And(bit7ZeroMask, Avx512Vbmi.PermuteVar64x8x2(lookupVectorA, x.AsByte(), lookupVectorB)).As<byte, T>();
Review comment on lines +144 to +155 (Member):

I think this is the only one that shouldn't simply be Widen+Lzcnt, but it does warrant a comment explaining how the lookup works.

In particular, how PermuteVar64x8x2 operates isn't immediately obvious, so it would be good to explain that x is used as an index: bit 6 selects the table, bits 5:0 select the element within that table, and any element with bit 7 set is zeroed.

Reply from the PR author (Contributor):

Makes sense. I've added a comment to better explain how x is being used as an index and how the intrinsic is choosing between the two lookup vectors.
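
The indexing scheme described in this thread can be sketched in scalar form as follows. This is an illustration with hypothetical names, not the comment that was added to the PR; a plain array replaces the first lookup vector, and the second lookup vector (all ones) is modeled as a constant.

```csharp
// Hypothetical scalar model of the 512-bit byte lookup (illustration only).
static class ByteLzcntLookupModel
{
    // Same values as lookupVectorA in the diff: entries for inputs 0..63.
    private static readonly byte[] TableA =
    {
        8, 7, 6, 6, 5, 5, 5, 5,
        4, 4, 4, 4, 4, 4, 4, 4,
        3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3,
        2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2,
    };

    public static byte LeadingZeroCountViaLookup(byte x)
    {
        // Bit 7 set: the top bit is one, so the result is zeroed.
        if ((x & 0x80) != 0) return 0;

        // Bit 6 set: the second table is selected, and every entry there is 1.
        if ((x & 0x40) != 0) return 1;

        // Otherwise bits 5:0 index the first table.
        return TableA[x & 0x3F];
    }
}
```

In the vector code the first branch corresponds to the final AND with the comparison mask, and the other two branches correspond to the two-table permute.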

}

return Vector512.Create(Invoke(x.GetLower()), Invoke(x.GetUpper()));