Skip to content

Commit

Permalink
addressing PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
benwtrent committed Nov 4, 2024
1 parent a196c9b commit 2c33bd9
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,30 @@
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY;

public class ESVectorUtil {

/**
* For xorBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time.
* On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when
* compared to Integer::bitCount. While Long::bitCount is optimal on x64. See
* https://bugs.openjdk.org/browse/JDK-8336000
*/
static final boolean XOR_BIT_COUNT_STRIDE_AS_INT = Constants.OS_ARCH.equals("aarch64");
private static final MethodHandle BIT_COUNT_MH;
static {
try {
// For xorBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time.
// On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when
// compared to Integer::bitCount. While Long::bitCount is optimal on x64. See
// https://bugs.openjdk.org/browse/JDK-8336000
BIT_COUNT_MH = Constants.OS_ARCH.equals("aarch64")
? MethodHandles.lookup()
.findStatic(ESVectorUtil.class, "andBitCountInt", MethodType.methodType(int.class, byte[].class, byte[].class))
: MethodHandles.lookup()
.findStatic(ESVectorUtil.class, "andBitCountLong", MethodType.methodType(int.class, byte[].class, byte[].class));
} catch (NoSuchMethodException | IllegalAccessException e) {
throw new AssertionError(e);
}
}

private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport();

Expand Down Expand Up @@ -84,7 +97,7 @@ public static float ipFloatBit(float[] q, byte[] d) {

/**
* AND bit count computed over signed bytes.
*
* Copied from Lucene's XOR implementation
* @param a bytes containing a vector
* @param b bytes containing another vector, of the same dimension
* @return the value of the AND bit count of the two vectors
Expand All @@ -93,16 +106,23 @@ public static int andBitCount(byte[] a, byte[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
if (XOR_BIT_COUNT_STRIDE_AS_INT) {
return andBitCountInt(a, b);
} else {
return andBitCountLong(a, b);
try {
return (int) BIT_COUNT_MH.invokeExact(a, b);
} catch (Throwable e) {
if (e instanceof Error err) {
throw err;
} else if (e instanceof RuntimeException re) {
throw re;
} else {
throw new RuntimeException(e);
}
}
}

/** AND bit count striding over 4 bytes at a time. */
static int andBitCountInt(byte[] a, byte[] b) {
int distance = 0, i = 0;
// limit to number of int values in the array iterating by int byte views
for (final int upperBound = a.length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
distance += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(a, i) & (int) BitUtil.VH_NATIVE_INT.get(b, i));
}
Expand All @@ -113,9 +133,10 @@ static int andBitCountInt(byte[] a, byte[] b) {
return distance;
}

/** AND bit count striding over 8 bytes at a time. */
/** AND bit count striding over 8 bytes at a time**/
static int andBitCountLong(byte[] a, byte[] b) {
int distance = 0, i = 0;
// limit to number of long values in the array iterating by long byte views
for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
distance += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(a, i) & (long) BitUtil.VH_NATIVE_LONG.get(b, i));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
static final ESVectorizationProvider defaultedProvider = BaseVectorizationTests.defaultProvider();
static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider();

public void testBitAndCount() {
testBasicBitAndImpl(ESVectorUtil::andBitCountLong);
}

public void testIpByteBinInvariants() {
int iterations = atLeast(10);
for (int i = 0; i < iterations; i++) {
Expand All @@ -41,6 +45,23 @@ interface IpByteBin {
long apply(byte[] q, byte[] d);
}

interface BitOps {
long apply(byte[] q, byte[] d);
}

void testBasicBitAndImpl(BitOps bitAnd) {
assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 0 }));
assertEquals(0, bitAnd.apply(new byte[] { 1 }, new byte[] { 0 }));
assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 1 }));
assertEquals(1, bitAnd.apply(new byte[] { 1 }, new byte[] { 1 }));
byte[] a = new byte[31];
byte[] b = new byte[31];
random().nextBytes(a);
random().nextBytes(b);
int expected = scalarBitAnd(a, b);
assertEquals(expected, bitAnd.apply(a, b));
}

void testBasicIpByteBinImpl(IpByteBin ipByteBinFunc) {
assertEquals(15L, ipByteBinFunc.apply(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 }));
assertEquals(30L, ipByteBinFunc.apply(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 }));
Expand Down Expand Up @@ -115,6 +136,14 @@ static int scalarIpByteBin(byte[] q, byte[] d) {
return res;
}

static int scalarBitAnd(byte[] a, byte[] b) {
int res = 0;
for (int i = 0; i < a.length; i++) {
res += Integer.bitCount((a[i] & b[i]) & 0xFF);
}
return res;
}

public static int popcount(byte[] a, int aOffset, byte[] b, int length) {
int res = 0;
for (int j = 0; j < length; j++) {
Expand Down

0 comments on commit 2c33bd9

Please sign in to comment.