diff --git a/docs/changelog/116082.yaml b/docs/changelog/116082.yaml new file mode 100644 index 0000000000000..35ca5fb1ea82e --- /dev/null +++ b/docs/changelog/116082.yaml @@ -0,0 +1,5 @@ +pr: 116082 +summary: Add support for bitwise inner-product in painless +area: Vector Search +type: enhancement +issues: [] diff --git a/docs/reference/vectors/vector-functions.asciidoc b/docs/reference/vectors/vector-functions.asciidoc index 2a80290cf9d3b..10dca8084e28a 100644 --- a/docs/reference/vectors/vector-functions.asciidoc +++ b/docs/reference/vectors/vector-functions.asciidoc @@ -16,7 +16,7 @@ This is the list of available vector functions and vector access methods: 6. <].vectorValue`>> – returns a vector's value as an array of floats 7. <].magnitude`>> – returns a vector's magnitude -NOTE: The `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors. +NOTE: The `cosineSimilarity` function is not supported for `bit` vectors. NOTE: The recommended way to access dense vectors is through the `cosineSimilarity`, `dotProduct`, `l1norm` or `l2norm` functions. Please note @@ -332,6 +332,92 @@ When using `bit` vectors, not all the vector functions are available. The suppor * <> – calculates Hamming distance, the sum of the bitwise XOR of the two vectors * <> – calculates L^1^ distance, this is simply the `hamming` distance * <> - calculates L^2^ distance, this is the square root of the `hamming` distance +* <> – calculates dot product. When comparing two `bit` vectors, +this is the sum of the bitwise AND of the two vectors. If providing `float[]` or `byte[]`, who has `dims` number of elements, as a query vector, the `dotProduct` is +the sum of the floating point values using the stored `bit` vector as a mask. -Currently, the `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors. +Here is an example of using dot-product with bit vectors. + +[source,console] +-------------------------------------------------- +PUT my-index-bit-vectors +{ + "mappings": { + "properties": { + "my_dense_vector": { + "type": "dense_vector", + "index": false, + "element_type": "bit", + "dims": 40 <1> + } + } + } +} + +PUT my-index-bit-vectors/_doc/1 +{ + "my_dense_vector": [8, 5, -15, 1, -7] <2> +} + +PUT my-index-bit-vectors/_doc/2 +{ + "my_dense_vector": [-1, 115, -3, 4, -128] +} + +PUT my-index-bit-vectors/_doc/3 +{ + "my_dense_vector": [2, 18, -5, 0, -124] +} + +POST my-index-bit-vectors/_refresh +-------------------------------------------------- +// TEST[continued] +<1> The number of dimensions or bits for the `bit` vector. +<2> This vector represents 5 bytes, or `5 * 8 = 40` bits, which equals the configured dimensions + +[source,console] +-------------------------------------------------- +GET my-index-bit-vectors/_search +{ + "query": { + "script_score": { + "query" : { + "match_all": {} + }, + "script": { + "source": "dotProduct(params.query_vector, 'my_dense_vector')", + "params": { + "query_vector": [8, 5, -15, 1, -7] <1> + } + } + } + } +} +-------------------------------------------------- +// TEST[continued] +<1> This vector is 40 bits, and thus will compute a bitwise `&` operation with the stored vectors. + +[source,console] +-------------------------------------------------- +GET my-index-bit-vectors/_search +{ + "query": { + "script_score": { + "query" : { + "match_all": {} + }, + "script": { + "source": "dotProduct(params.query_vector, 'my_dense_vector')", + "params": { + "query_vector": [0.23, 1.45, 3.67, 4.89, -0.56, 2.34, 3.21, 1.78, -2.45, 0.98, -0.12, 3.45, 4.56, 2.78, 1.23, 0.67, 3.89, 4.12, -2.34, 1.56, 0.78, 3.21, 4.12, 2.45, -1.67, 0.34, -3.45, 4.56, -2.78, 1.23, -0.67, 3.89, -4.34, 2.12, -1.56, 0.78, -3.21, 4.45, 2.12, 1.67] <1> + } + } + } + } +} +-------------------------------------------------- +// TEST[continued] +<1> This vector is 40 individual dimensions, and thus will sum the floating point values using the stored `bit` vector as a mask. + +Currently, the `cosineSimilarity` function is not supported for `bit` vectors. diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java index 91193d5fa6eaf..de2cb9042610b 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java @@ -9,13 +9,36 @@ package org.elasticsearch.simdvec; +import org.apache.lucene.util.BitUtil; +import org.apache.lucene.util.Constants; import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport; import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; + import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY; public class ESVectorUtil { + private static final MethodHandle BIT_COUNT_MH; + static { + try { + // For xorBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time. + // On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when + // compared to Integer::bitCount. While Long::bitCount is optimal on x64. See + // https://bugs.openjdk.org/browse/JDK-8336000 + BIT_COUNT_MH = Constants.OS_ARCH.equals("aarch64") + ? MethodHandles.lookup() + .findStatic(ESVectorUtil.class, "andBitCountInt", MethodType.methodType(int.class, byte[].class, byte[].class)) + : MethodHandles.lookup() + .findStatic(ESVectorUtil.class, "andBitCountLong", MethodType.methodType(int.class, byte[].class, byte[].class)); + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new AssertionError(e); + } + } + private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport(); public static long ipByteBinByte(byte[] q, byte[] d) { @@ -24,4 +47,103 @@ public static long ipByteBinByte(byte[] q, byte[] d) { } return IMPL.ipByteBinByte(q, d); } + + /** + * Compute the inner product of two vectors, where the query vector is a byte vector and the document vector is a bit vector. + * This will return the sum of the query vector values using the document vector as a mask. + * @param q the query vector + * @param d the document vector + * @return the inner product of the two vectors + */ + public static int ipByteBit(byte[] q, byte[] d) { + if (q.length != d.length * Byte.SIZE) { + throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length); + } + int result = 0; + // now combine the two vectors, summing the byte dimensions where the bit in d is `1` + for (int i = 0; i < d.length; i++) { + byte mask = d[i]; + for (int j = 0; j < Byte.SIZE; j++) { + if ((mask & (1 << j)) != 0) { + result += q[i * Byte.SIZE + j]; + } + } + } + return result; + } + + /** + * Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a bit vector. + * This will return the sum of the query vector values using the document vector as a mask. + * @param q the query vector + * @param d the document vector + * @return the inner product of the two vectors + */ + public static float ipFloatBit(float[] q, byte[] d) { + if (q.length != d.length * Byte.SIZE) { + throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length); + } + float result = 0; + for (int i = 0; i < d.length; i++) { + byte mask = d[i]; + for (int j = 0; j < Byte.SIZE; j++) { + if ((mask & (1 << j)) != 0) { + result += q[i * Byte.SIZE + j]; + } + } + } + return result; + } + + /** + * AND bit count computed over signed bytes. + * Copied from Lucene's XOR implementation + * @param a bytes containing a vector + * @param b bytes containing another vector, of the same dimension + * @return the value of the AND bit count of the two vectors + */ + public static int andBitCount(byte[] a, byte[] b) { + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); + } + try { + return (int) BIT_COUNT_MH.invokeExact(a, b); + } catch (Throwable e) { + if (e instanceof Error err) { + throw err; + } else if (e instanceof RuntimeException re) { + throw re; + } else { + throw new RuntimeException(e); + } + } + } + + /** AND bit count striding over 4 bytes at a time. */ + static int andBitCountInt(byte[] a, byte[] b) { + int distance = 0, i = 0; + // limit to number of int values in the array iterating by int byte views + for (final int upperBound = a.length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + distance += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(a, i) & (int) BitUtil.VH_NATIVE_INT.get(b, i)); + } + // tail: + for (; i < a.length; i++) { + distance += Integer.bitCount((a[i] & b[i]) & 0xFF); + } + return distance; + } + + /** AND bit count striding over 8 bytes at a time**/ + static int andBitCountLong(byte[] a, byte[] b) { + int distance = 0, i = 0; + // limit to number of long values in the array iterating by long byte views + for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) { + distance += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(a, i) & (long) BitUtil.VH_NATIVE_LONG.get(b, i)); + } + // tail: + for (; i < a.length; i++) { + distance += Integer.bitCount((a[i] & b[i]) & 0xFF); + } + return distance; + } } diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java index 0dbc41c0c1055..e9e0fd58f7638 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java @@ -21,6 +21,10 @@ public class ESVectorUtilTests extends BaseVectorizationTests { static final ESVectorizationProvider defaultedProvider = BaseVectorizationTests.defaultProvider(); static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider(); + public void testBitAndCount() { + testBasicBitAndImpl(ESVectorUtil::andBitCountLong); + } + public void testIpByteBinInvariants() { int iterations = atLeast(10); for (int i = 0; i < iterations; i++) { @@ -41,6 +45,23 @@ interface IpByteBin { long apply(byte[] q, byte[] d); } + interface BitOps { + long apply(byte[] q, byte[] d); + } + + void testBasicBitAndImpl(BitOps bitAnd) { + assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 0 })); + assertEquals(0, bitAnd.apply(new byte[] { 1 }, new byte[] { 0 })); + assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 1 })); + assertEquals(1, bitAnd.apply(new byte[] { 1 }, new byte[] { 1 })); + byte[] a = new byte[31]; + byte[] b = new byte[31]; + random().nextBytes(a); + random().nextBytes(b); + int expected = scalarBitAnd(a, b); + assertEquals(expected, bitAnd.apply(a, b)); + } + void testBasicIpByteBinImpl(IpByteBin ipByteBinFunc) { assertEquals(15L, ipByteBinFunc.apply(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 })); assertEquals(30L, ipByteBinFunc.apply(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 })); @@ -115,6 +136,14 @@ static int scalarIpByteBin(byte[] q, byte[] d) { return res; } + static int scalarBitAnd(byte[] a, byte[] b) { + int res = 0; + for (int i = 0; i < a.length; i++) { + res += Integer.bitCount((a[i] & b[i]) & 0xFF); + } + return res; + } + public static int popcount(byte[] a, int aOffset, byte[] b, int length) { int res = 0; for (int j = 0; j < length; j++) { diff --git a/modules/lang-painless/build.gradle b/modules/lang-painless/build.gradle index e8751e09045d4..0bcc993c3d4e4 100644 --- a/modules/lang-painless/build.gradle +++ b/modules/lang-painless/build.gradle @@ -53,7 +53,7 @@ tasks.named("dependencyLicenses").configure { restResources { restApi { include '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'bulk', 'update', - 'scripts_painless_execute', 'put_script', 'delete_script' + 'scripts_painless_execute', 'put_script', 'delete_script', 'capabilities' } } diff --git a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml index 4c195a0e32623..2ee38f849e9d4 100644 --- a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml +++ b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml @@ -101,9 +101,15 @@ setup: - match: {hits.hits.2._id: "3"} - close_to: {hits.hits.2._score: {value: 3.4641016, error: 0.01}} - --- "Dot Product is not supported": + - skip: + features: [capabilities] + capabilities: + - method: POST + path: /_search + capabilities: [ byte_float_bit_dot_product ] + reason: Capability required to run test - do: catch: bad_request headers: @@ -131,7 +137,6 @@ setup: source: "dotProduct(params.query_vector, 'vector')" params: query_vector: "006ff30e84" - --- "Cosine Similarity is not supported": - do: @@ -388,3 +393,119 @@ setup: - match: {hits.hits.2._id: "3"} - match: {hits.hits.2._score: 11.0} +--- +"Dot product with float": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ byte_float_bit_dot_product ] + test_runner_features: [capabilities, close_to] + reason: Capability required to run test + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: { match_all: { } } + script: + source: "dotProduct(params.query_vector, 'vector')" + params: + query_vector: [0.23, 1.45, 3.67, 4.89, -0.56, 2.34, 3.21, 1.78, -2.45, 0.98, -0.12, 3.45, 4.56, 2.78, 1.23, 0.67, 3.89, 4.12, -2.34, 1.56, 0.78, 3.21, 4.12, 2.45, -1.67, 0.34, -3.45, 4.56, -2.78, 1.23, -0.67, 3.89, -4.34, 2.12, -1.56, 0.78, -3.21, 4.45, 2.12, 1.67] + + - match: { hits.total: 3 } + + - match: {hits.hits.0._id: "2"} + - close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}} + + - match: {hits.hits.1._id: "3"} + - close_to: {hits.hits.1._score:{value: 27.23, error: 0.01}} + + - match: {hits.hits.2._id: "1"} + - close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}} + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: { match_all: { } } + script: + source: "dotProduct(params.query_vector, 'indexed_vector')" + params: + query_vector: [0.23, 1.45, 3.67, 4.89, -0.56, 2.34, 3.21, 1.78, -2.45, 0.98, -0.12, 3.45, 4.56, 2.78, 1.23, 0.67, 3.89, 4.12, -2.34, 1.56, 0.78, 3.21, 4.12, 2.45, -1.67, 0.34, -3.45, 4.56, -2.78, 1.23, -0.67, 3.89, -4.34, 2.12, -1.56, 0.78, -3.21, 4.45, 2.12, 1.67] + + - match: { hits.total: 3 } + + - match: {hits.hits.0._id: "2"} + - close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}} + + - match: {hits.hits.1._id: "3"} + - close_to: {hits.hits.1._score:{value: 27.23, error: 0.01}} + + - match: {hits.hits.2._id: "1"} + - close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}} +--- +"Dot product with byte": + - requires: + capabilities: + - method: POST + path: /_search + capabilities: [ byte_float_bit_dot_product ] + test_runner_features: capabilities + reason: Capability required to run test + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: { match_all: { } } + script: + source: "dotProduct(params.query_vector, 'vector')" + params: + query_vector: [12, -34, 56, -78, 90, 12, 34, -56, 78, -90, 23, -45, 67, -89, 12, 34, 56, 78, 90, -12, 34, -56, 78, -90, 23, -45, 67, -89, 12, -34, 56, -78, 90, -12, 34, -56, 78, 90, 23, -45] + + - match: { hits.total: 3 } + + - match: {hits.hits.0._id: "1"} + - match: {hits.hits.0._score: 248} + + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.1._score: 136} + + - match: {hits.hits.2._id: "3"} + - match: {hits.hits.2._score: 20} + + - do: + headers: + Content-Type: application/json + search: + rest_total_hits_as_int: true + body: + query: + script_score: + query: { match_all: { } } + script: + source: "dotProduct(params.query_vector, 'indexed_vector')" + params: + query_vector: [12, -34, 56, -78, 90, 12, 34, -56, 78, -90, 23, -45, 67, -89, 12, 34, 56, 78, 90, -12, 34, -56, 78, -90, 23, -45, 67, -89, 12, -34, 56, -78, 90, -12, 34, -56, 78, 90, 23, -45] + + - match: { hits.total: 3 } + + - match: {hits.hits.0._id: "1"} + - match: {hits.hits.0._score: 248} + + - match: {hits.hits.1._id: "2"} + - match: {hits.hits.1._score: 136} + + - match: {hits.hits.2._id: "3"} + - match: {hits.hits.2._score: 20} diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java index 7828bb956a160..4efdfc66e8b5e 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java @@ -22,9 +22,12 @@ private SearchCapabilities() {} private static final String RANGE_REGEX_INTERVAL_QUERY_CAPABILITY = "range_regexp_interval_queries"; /** Support synthetic source with `bit` type in `dense_vector` field when `index` is set to `false`. */ private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source"; + /** Support Byte and Float with Bit dot product. */ + private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product"; public static final Set CAPABILITIES = Set.of( RANGE_REGEX_INTERVAL_QUERY_CAPABILITY, - BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY + BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY, + BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY ); } diff --git a/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java b/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java index 809e9811f3673..e773bceb5ec05 100644 --- a/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java +++ b/server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java @@ -307,6 +307,87 @@ public interface DotProductInterface { double dotProduct(); } + public static class BitDotProduct extends DenseVectorFunction implements DotProductInterface { + private final byte[] byteQueryVector; + private final float[] floatQueryVector; + + public BitDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) { + super(scoreScript, field); + if (field.getElementType() != DenseVectorFieldMapper.ElementType.BIT) { + throw new IllegalArgumentException("cannot calculate bit dot product for non-bit vectors"); + } + int fieldDims = field.get().getDims(); + if (fieldDims != queryVector.length * Byte.SIZE && fieldDims != queryVector.length) { + throw new IllegalArgumentException( + "The query vector has an incorrect number of dimensions. Must be [" + + fieldDims / 8 + + "] for bitwise operations, or [" + + fieldDims + + "] for byte wise operations: provided [" + + queryVector.length + + "]." + ); + } + this.byteQueryVector = queryVector; + this.floatQueryVector = null; + } + + public BitDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, List queryVector) { + super(scoreScript, field); + if (field.getElementType() != DenseVectorFieldMapper.ElementType.BIT) { + throw new IllegalArgumentException("cannot calculate bit dot product for non-bit vectors"); + } + float[] floatQueryVector = new float[queryVector.size()]; + byte[] byteQueryVector = new byte[queryVector.size()]; + boolean isFloat = false; + for (int i = 0; i < queryVector.size(); i++) { + Number number = queryVector.get(i); + floatQueryVector[i] = number.floatValue(); + byteQueryVector[i] = number.byteValue(); + if (isFloat + || floatQueryVector[i] % 1.0f != 0.0f + || floatQueryVector[i] < Byte.MIN_VALUE + || floatQueryVector[i] > Byte.MAX_VALUE) { + isFloat = true; + } + } + int fieldDims = field.get().getDims(); + if (isFloat) { + this.floatQueryVector = floatQueryVector; + this.byteQueryVector = null; + if (fieldDims != floatQueryVector.length) { + throw new IllegalArgumentException( + "The query vector has an incorrect number of dimensions. Must be [" + + fieldDims + + "] for float wise operations: provided [" + + floatQueryVector.length + + "]." + ); + } + } else { + this.floatQueryVector = null; + this.byteQueryVector = byteQueryVector; + if (fieldDims != byteQueryVector.length * Byte.SIZE && fieldDims != byteQueryVector.length) { + throw new IllegalArgumentException( + "The query vector has an incorrect number of dimensions. Must be [" + + fieldDims / 8 + + "] for bitwise operations, or [" + + fieldDims + + "] for byte wise operations: provided [" + + byteQueryVector.length + + "]." + ); + } + } + } + + @Override + public double dotProduct() { + setNextVector(); + return byteQueryVector != null ? field.get().dotProduct(byteQueryVector) : field.get().dotProduct(floatQueryVector); + } + } + public static class ByteDotProduct extends ByteDenseVectorFunction implements DotProductInterface { public ByteDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, List queryVector) { @@ -343,7 +424,16 @@ public static final class DotProduct { public DotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) { DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName); function = switch (field.getElementType()) { - case BYTE, BIT -> { + case BIT -> { + if (queryVector instanceof List) { + yield new BitDotProduct(scoreScript, field, (List) queryVector); + } else if (queryVector instanceof String s) { + byte[] parsedQueryVector = HexFormat.of().parseHex(s); + yield new BitDotProduct(scoreScript, field, parsedQueryVector); + } + throw new IllegalArgumentException("Unsupported input object for bit vectors: " + queryVector.getClass().getName()); + } + case BYTE -> { if (queryVector instanceof List) { yield new ByteDotProduct(scoreScript, field, (List) queryVector); } else if (queryVector instanceof String s) { diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BitBinaryDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BitBinaryDenseVector.java index 9c0b7ce2e5d6e..fecca9c1b3929 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/BitBinaryDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BitBinaryDenseVector.java @@ -13,6 +13,10 @@ import java.util.List; +import static org.elasticsearch.simdvec.ESVectorUtil.andBitCount; +import static org.elasticsearch.simdvec.ESVectorUtil.ipByteBit; +import static org.elasticsearch.simdvec.ESVectorUtil.ipFloatBit; + public class BitBinaryDenseVector extends ByteBinaryDenseVector { public BitBinaryDenseVector(byte[] vectorValue, BytesRef docVector, int dims) { @@ -54,7 +58,11 @@ public double l2Norm(List queryVector) { @Override public int dotProduct(byte[] queryVector) { - throw new UnsupportedOperationException("dotProduct is not supported for bit vectors."); + if (queryVector.length == vectorValue.length) { + // assume that the query vector is a bit vector and do a bitwise AND + return andBitCount(vectorValue, queryVector); + } + return ipByteBit(queryVector, vectorValue); } @Override @@ -79,7 +87,7 @@ public double cosineSimilarity(List queryVector) { @Override public double dotProduct(float[] queryVector) { - throw new UnsupportedOperationException("dotProduct is not supported for bit vectors."); + return ipFloatBit(queryVector, vectorValue); } @Override diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/BitKnnDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/BitKnnDenseVector.java index b0171325d4089..fcfc4546f6e73 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/BitKnnDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/BitKnnDenseVector.java @@ -11,6 +11,10 @@ import java.util.List; +import static org.elasticsearch.simdvec.ESVectorUtil.andBitCount; +import static org.elasticsearch.simdvec.ESVectorUtil.ipByteBit; +import static org.elasticsearch.simdvec.ESVectorUtil.ipFloatBit; + public class BitKnnDenseVector extends ByteKnnDenseVector { public BitKnnDenseVector(byte[] vector) { @@ -61,7 +65,11 @@ public double l2Norm(List queryVector) { @Override public int dotProduct(byte[] queryVector) { - throw new UnsupportedOperationException("dotProduct is not supported for bit vectors."); + if (queryVector.length == docVector.length) { + // assume that the query vector is a bit vector and do a bitwise AND + return andBitCount(docVector, queryVector); + } + return ipByteBit(queryVector, docVector); } @Override @@ -86,7 +94,7 @@ public double cosineSimilarity(List queryVector) { @Override public double dotProduct(float[] queryVector) { - throw new UnsupportedOperationException("dotProduct is not supported for bit vectors."); + return ipFloatBit(queryVector, docVector); } @Override diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java index a01d1fcbdb4ed..9593f61fcba65 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java @@ -21,7 +21,7 @@ public class ByteBinaryDenseVector implements DenseVector { public static final int MAGNITUDE_BYTES = 4; private final BytesRef docVector; - private final byte[] vectorValue; + protected final byte[] vectorValue; protected final int dims; private float[] floatDocVector; diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/BinaryDenseVectorScriptDocValuesTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/BinaryDenseVectorScriptDocValuesTests.java index d5360afddc3ad..7f67cce38c5d5 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/BinaryDenseVectorScriptDocValuesTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/BinaryDenseVectorScriptDocValuesTests.java @@ -236,15 +236,19 @@ public long cost() { } public static BytesRef mockEncodeDenseVector(float[] values, ElementType elementType, IndexVersion indexVersion) { + int dims = values.length; + if (elementType == ElementType.BIT) { + dims *= Byte.SIZE; + } int numBytes = indexVersion.onOrAfter(DenseVectorFieldMapper.MAGNITUDE_STORED_INDEX_VERSION) - ? elementType.getNumBytes(values.length) + DenseVectorFieldMapper.MAGNITUDE_BYTES - : elementType.getNumBytes(values.length); + ? elementType.getNumBytes(dims) + DenseVectorFieldMapper.MAGNITUDE_BYTES + : elementType.getNumBytes(dims); double dotProduct = 0f; ByteBuffer byteBuffer = elementType.createByteBuffer(indexVersion, numBytes); for (float value : values) { if (elementType == ElementType.FLOAT) { byteBuffer.putFloat(value); - } else if (elementType == ElementType.BYTE) { + } else if (elementType == ElementType.BYTE || elementType == ElementType.BIT) { byteBuffer.put((byte) value); } else { throw new IllegalStateException("unknown element_type [" + elementType + "]"); diff --git a/server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java b/server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java index e5ebcf8b3303d..2d9caca1ba6a1 100644 --- a/server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java +++ b/server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java @@ -20,6 +20,8 @@ import org.elasticsearch.script.VectorScoreScriptUtils.L1Norm; import org.elasticsearch.script.VectorScoreScriptUtils.L2Norm; import org.elasticsearch.script.field.vectors.BinaryDenseVectorDocValuesField; +import org.elasticsearch.script.field.vectors.BitBinaryDenseVectorDocValuesField; +import org.elasticsearch.script.field.vectors.BitKnnDenseVectorDocValuesField; import org.elasticsearch.script.field.vectors.ByteBinaryDenseVectorDocValuesField; import org.elasticsearch.script.field.vectors.ByteKnnDenseVectorDocValuesField; import org.elasticsearch.script.field.vectors.DenseVectorDocValuesField; @@ -229,6 +231,61 @@ public void testByteVectorClassBindings() throws IOException { } } + public void testBitVectorClassBindingsDotProduct() throws IOException { + String fieldName = "vector"; + int dims = 8; + float[] docVector = new float[] { 124 }; + // 124 in binary is b01111100 + List queryVector = Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4, (byte) 1, (byte) 125, (byte) -12); + List floatQueryVector = Arrays.asList(1.4f, -1.4f, 0.42f, 0.0f, 1f, -1f, -0.42f, 1.2f); + List invalidQueryVector = Arrays.asList((byte) 1, (byte) 1); + String hexidecimalString = HexFormat.of().formatHex(new byte[] { 124 }); + + List fields = List.of( + new BitBinaryDenseVectorDocValuesField( + BinaryDenseVectorScriptDocValuesTests.wrap(new float[][] { docVector }, ElementType.BIT, IndexVersion.current()), + "test", + ElementType.BIT, + dims + ), + new BitKnnDenseVectorDocValuesField(KnnDenseVectorScriptDocValuesTests.wrapBytes(new float[][] { docVector }), "test", dims) + ); + for (DenseVectorDocValuesField field : fields) { + field.setNextDocId(0); + + ScoreScript scoreScript = mock(ScoreScript.class); + when(scoreScript.field(fieldName)).thenAnswer(mock -> field); + + // Test cosine similarity explicitly, as it must perform special logic on top of the doc values + DotProduct function = new DotProduct(scoreScript, queryVector, fieldName); + assertEquals("dotProduct result is not equal to the expected value!", -12 + 2 + 4 + 1 + 125, function.dotProduct(), 0.001); + + function = new DotProduct(scoreScript, floatQueryVector, fieldName); + assertEquals( + "dotProduct result is not equal to the expected value!", + 0.42f + 0f + 1f - 1f - 0.42f, + function.dotProduct(), + 0.001 + ); + + function = new DotProduct(scoreScript, hexidecimalString, fieldName); + assertEquals("dotProduct result is not equal to the expected value!", Integer.bitCount(124), function.dotProduct(), 0.0); + + // Check each function rejects query vectors with the wrong dimension + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> new DotProduct(scoreScript, invalidQueryVector, fieldName) + ); + assertThat( + e.getMessage(), + containsString( + "query vector has an incorrect number of dimensions. " + + "Must be [1] for bitwise operations, or [8] for byte wise operations: provided [2]." + ) + ); + } + } + public void testByteVsFloatSimilarity() throws IOException { int dims = 5; float[] docVector = new float[] { 1f, 127f, -128f, 5f, -10f };