diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java index 659cc89bfe46d..de91833c99842 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java @@ -16,11 +16,11 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.elasticsearch.script.field.vectors.ESVectorUtil; import java.io.IOException; @@ -100,7 +100,7 @@ public RandomVectorScorer getRandomVectorScorer( } static float hammingScore(byte[] a, byte[] b) { - return ((a.length * Byte.SIZE) - VectorUtil.xorBitCount(a, b)) / (float) (a.length * Byte.SIZE); + return ((a.length * Byte.SIZE) - ESVectorUtil.xorBitCount(a, b)) / (float) (a.length * Byte.SIZE); } static class HammingVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java index f2ff8fbccd2fb..e5c2d6a370f12 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java @@ -102,7 +102,7 @@ public double l1Norm(List queryVector) { @Override public int hamming(byte[] queryVector) { - return VectorUtil.xorBitCount(queryVector, vectorValue); + return ESVectorUtil.xorBitCount(queryVector, vectorValue); } @Override diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java index e0ba032826aa1..0145eb3eae04b 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java @@ -103,7 +103,7 @@ public double l1Norm(List queryVector) { @Override public int hamming(byte[] queryVector) { - return VectorUtil.xorBitCount(queryVector, docVector); + return ESVectorUtil.xorBitCount(queryVector, docVector); } @Override diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ESVectorUtil.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ESVectorUtil.java new file mode 100644 index 0000000000000..7d9542bccf357 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ESVectorUtil.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.util.BitUtil; +import org.apache.lucene.util.Constants; + +/** + * This class consists of a single utility method that provides XOR bit count computed over signed bytes. + * Remove this class when Lucene version > 9.11 is released, and replace with Lucene's VectorUtil directly. + */ +public class ESVectorUtil { + + /** + * For xorBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time. + * On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when + * compared to Integer::bitCount. While Long::bitCount is optimal on x64. + */ + static final boolean XOR_BIT_COUNT_STRIDE_AS_INT = Constants.OS_ARCH.equals("aarch64"); + + /** + * XOR bit count computed over signed bytes. + * + * @param a bytes containing a vector + * @param b bytes containing another vector, of the same dimension + * @return the value of the XOR bit count of the two vectors + */ + public static int xorBitCount(byte[] a, byte[] b) { + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); + } + if (XOR_BIT_COUNT_STRIDE_AS_INT) { + return xorBitCountInt(a, b); + } else { + return xorBitCountLong(a, b); + } + } + + /** XOR bit count striding over 4 bytes at a time. */ + static int xorBitCountInt(byte[] a, byte[] b) { + int distance = 0, i = 0; + for (final int upperBound = a.length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + distance += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(a, i) ^ (int) BitUtil.VH_NATIVE_INT.get(b, i)); + } + // tail: + for (; i < a.length; i++) { + distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF); + } + return distance; + } + + /** XOR bit count striding over 8 bytes at a time. */ + static int xorBitCountLong(byte[] a, byte[] b) { + int distance = 0, i = 0; + for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) { + distance += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(a, i) ^ (long) BitUtil.VH_NATIVE_LONG.get(b, i)); + } + // tail: + for (; i < a.length; i++) { + distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF); + } + return distance; + } + + private ESVectorUtil() {} +}