Skip to content

Commit

Permalink
Add support for bitwise inner-product in painless (elastic#116082)
Browse files Browse the repository at this point in the history
This adds bitwise inner product to painless. 

The idea here is:

 - For two bit arrays, which we determine to be a byte array whose dimensions match `dense_vector.dim/8`, we simply return bitwise `&`
 - For a stored bit array (remember, with `dense_vector.dim/8` bytes), sum up the provided byte or float array using the bit array as a mask.

This is effectively supporting asynchronous quantization. A prime
example of how this works is:
https://github.com/cohere-ai/BinaryVectorDB

Basically, you do your initial search against the binary space and then
rerank with a differently quantized vector allowing for more information
without additional storage space. 

closes:  elastic#111232
  • Loading branch information
benwtrent committed Nov 5, 2024
1 parent 175cb28 commit 203f130
Show file tree
Hide file tree
Showing 13 changed files with 552 additions and 15 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/116082.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 116082
summary: Add support for bitwise inner-product in painless
area: Vector Search
type: enhancement
issues: []
90 changes: 88 additions & 2 deletions docs/reference/vectors/vector-functions.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ This is the list of available vector functions and vector access methods:
6. <<vector-functions-accessing-vectors,`doc[<field>].vectorValue`>> – returns a vector's value as an array of floats
7. <<vector-functions-accessing-vectors,`doc[<field>].magnitude`>> – returns a vector's magnitude

NOTE: The `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
NOTE: The `cosineSimilarity` function is not supported for `bit` vectors.

NOTE: The recommended way to access dense vectors is through the
`cosineSimilarity`, `dotProduct`, `l1norm` or `l2norm` functions. Please note
Expand Down Expand Up @@ -332,6 +332,92 @@ When using `bit` vectors, not all the vector functions are available. The suppor
* <<vector-functions-hamming,`hamming`>> – calculates Hamming distance, the sum of the bitwise XOR of the two vectors
* <<vector-functions-l1,`l1norm`>> – calculates L^1^ distance, this is simply the `hamming` distance
* <<vector-functions-l2,`l2norm`>> - calculates L^2^ distance, this is the square root of the `hamming` distance
* <<vector-functions-dot-product,`dotProduct`>> – calculates dot product. When comparing two `bit` vectors,
this is the sum of the bitwise AND of the two vectors. If providing `float[]` or `byte[]`, who has `dims` number of elements, as a query vector, the `dotProduct` is
the sum of the floating point values using the stored `bit` vector as a mask.

Currently, the `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
Here is an example of using dot-product with bit vectors.

[source,console]
--------------------------------------------------
PUT my-index-bit-vectors
{
"mappings": {
"properties": {
"my_dense_vector": {
"type": "dense_vector",
"index": false,
"element_type": "bit",
"dims": 40 <1>
}
}
}
}
PUT my-index-bit-vectors/_doc/1
{
"my_dense_vector": [8, 5, -15, 1, -7] <2>
}
PUT my-index-bit-vectors/_doc/2
{
"my_dense_vector": [-1, 115, -3, 4, -128]
}
PUT my-index-bit-vectors/_doc/3
{
"my_dense_vector": [2, 18, -5, 0, -124]
}
POST my-index-bit-vectors/_refresh
--------------------------------------------------
// TEST[continued]
<1> The number of dimensions or bits for the `bit` vector.
<2> This vector represents 5 bytes, or `5 * 8 = 40` bits, which equals the configured dimensions

[source,console]
--------------------------------------------------
GET my-index-bit-vectors/_search
{
"query": {
"script_score": {
"query" : {
"match_all": {}
},
"script": {
"source": "dotProduct(params.query_vector, 'my_dense_vector')",
"params": {
"query_vector": [8, 5, -15, 1, -7] <1>
}
}
}
}
}
--------------------------------------------------
// TEST[continued]
<1> This vector is 40 bits, and thus will compute a bitwise `&` operation with the stored vectors.

[source,console]
--------------------------------------------------
GET my-index-bit-vectors/_search
{
"query": {
"script_score": {
"query" : {
"match_all": {}
},
"script": {
"source": "dotProduct(params.query_vector, 'my_dense_vector')",
"params": {
"query_vector": [0.23, 1.45, 3.67, 4.89, -0.56, 2.34, 3.21, 1.78, -2.45, 0.98, -0.12, 3.45, 4.56, 2.78, 1.23, 0.67, 3.89, 4.12, -2.34, 1.56, 0.78, 3.21, 4.12, 2.45, -1.67, 0.34, -3.45, 4.56, -2.78, 1.23, -0.67, 3.89, -4.34, 2.12, -1.56, 0.78, -3.21, 4.45, 2.12, 1.67] <1>
}
}
}
}
}
--------------------------------------------------
// TEST[continued]
<1> This vector is 40 individual dimensions, and thus will sum the floating point values using the stored `bit` vector as a mask.

Currently, the `cosineSimilarity` function is not supported for `bit` vectors.

122 changes: 122 additions & 0 deletions libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,36 @@

package org.elasticsearch.simdvec;

import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Constants;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY;

public class ESVectorUtil {

private static final MethodHandle BIT_COUNT_MH;
static {
try {
// For xorBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time.
// On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when
// compared to Integer::bitCount. While Long::bitCount is optimal on x64. See
// https://bugs.openjdk.org/browse/JDK-8336000
BIT_COUNT_MH = Constants.OS_ARCH.equals("aarch64")
? MethodHandles.lookup()
.findStatic(ESVectorUtil.class, "andBitCountInt", MethodType.methodType(int.class, byte[].class, byte[].class))
: MethodHandles.lookup()
.findStatic(ESVectorUtil.class, "andBitCountLong", MethodType.methodType(int.class, byte[].class, byte[].class));
} catch (NoSuchMethodException | IllegalAccessException e) {
throw new AssertionError(e);
}
}

private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport();

public static long ipByteBinByte(byte[] q, byte[] d) {
Expand All @@ -24,4 +47,103 @@ public static long ipByteBinByte(byte[] q, byte[] d) {
}
return IMPL.ipByteBinByte(q, d);
}

/**
* Compute the inner product of two vectors, where the query vector is a byte vector and the document vector is a bit vector.
* This will return the sum of the query vector values using the document vector as a mask.
* @param q the query vector
* @param d the document vector
* @return the inner product of the two vectors
*/
public static int ipByteBit(byte[] q, byte[] d) {
if (q.length != d.length * Byte.SIZE) {
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
}
int result = 0;
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
for (int i = 0; i < d.length; i++) {
byte mask = d[i];
for (int j = 0; j < Byte.SIZE; j++) {
if ((mask & (1 << j)) != 0) {
result += q[i * Byte.SIZE + j];
}
}
}
return result;
}

/**
* Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a bit vector.
* This will return the sum of the query vector values using the document vector as a mask.
* @param q the query vector
* @param d the document vector
* @return the inner product of the two vectors
*/
public static float ipFloatBit(float[] q, byte[] d) {
if (q.length != d.length * Byte.SIZE) {
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
}
float result = 0;
for (int i = 0; i < d.length; i++) {
byte mask = d[i];
for (int j = 0; j < Byte.SIZE; j++) {
if ((mask & (1 << j)) != 0) {
result += q[i * Byte.SIZE + j];
}
}
}
return result;
}

/**
* AND bit count computed over signed bytes.
* Copied from Lucene's XOR implementation
* @param a bytes containing a vector
* @param b bytes containing another vector, of the same dimension
* @return the value of the AND bit count of the two vectors
*/
public static int andBitCount(byte[] a, byte[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
try {
return (int) BIT_COUNT_MH.invokeExact(a, b);
} catch (Throwable e) {
if (e instanceof Error err) {
throw err;
} else if (e instanceof RuntimeException re) {
throw re;
} else {
throw new RuntimeException(e);
}
}
}

/** AND bit count striding over 4 bytes at a time. */
static int andBitCountInt(byte[] a, byte[] b) {
int distance = 0, i = 0;
// limit to number of int values in the array iterating by int byte views
for (final int upperBound = a.length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
distance += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(a, i) & (int) BitUtil.VH_NATIVE_INT.get(b, i));
}
// tail:
for (; i < a.length; i++) {
distance += Integer.bitCount((a[i] & b[i]) & 0xFF);
}
return distance;
}

/** AND bit count striding over 8 bytes at a time**/
static int andBitCountLong(byte[] a, byte[] b) {
int distance = 0, i = 0;
// limit to number of long values in the array iterating by long byte views
for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
distance += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(a, i) & (long) BitUtil.VH_NATIVE_LONG.get(b, i));
}
// tail:
for (; i < a.length; i++) {
distance += Integer.bitCount((a[i] & b[i]) & 0xFF);
}
return distance;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
static final ESVectorizationProvider defaultedProvider = BaseVectorizationTests.defaultProvider();
static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider();

public void testBitAndCount() {
testBasicBitAndImpl(ESVectorUtil::andBitCountLong);
}

public void testIpByteBinInvariants() {
int iterations = atLeast(10);
for (int i = 0; i < iterations; i++) {
Expand All @@ -41,6 +45,23 @@ interface IpByteBin {
long apply(byte[] q, byte[] d);
}

interface BitOps {
long apply(byte[] q, byte[] d);
}

void testBasicBitAndImpl(BitOps bitAnd) {
assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 0 }));
assertEquals(0, bitAnd.apply(new byte[] { 1 }, new byte[] { 0 }));
assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 1 }));
assertEquals(1, bitAnd.apply(new byte[] { 1 }, new byte[] { 1 }));
byte[] a = new byte[31];
byte[] b = new byte[31];
random().nextBytes(a);
random().nextBytes(b);
int expected = scalarBitAnd(a, b);
assertEquals(expected, bitAnd.apply(a, b));
}

void testBasicIpByteBinImpl(IpByteBin ipByteBinFunc) {
assertEquals(15L, ipByteBinFunc.apply(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 }));
assertEquals(30L, ipByteBinFunc.apply(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 }));
Expand Down Expand Up @@ -115,6 +136,14 @@ static int scalarIpByteBin(byte[] q, byte[] d) {
return res;
}

static int scalarBitAnd(byte[] a, byte[] b) {
int res = 0;
for (int i = 0; i < a.length; i++) {
res += Integer.bitCount((a[i] & b[i]) & 0xFF);
}
return res;
}

public static int popcount(byte[] a, int aOffset, byte[] b, int length) {
int res = 0;
for (int j = 0; j < length; j++) {
Expand Down
6 changes: 5 additions & 1 deletion modules/lang-painless/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,14 @@ tasks.named("dependencyLicenses").configure {
mapping from: /asm-.*/, to: 'asm'
}

tasks.named("yamlRestCompatTestTransform").configure({ task ->
task.skipTest("painless/146_dense_vector_bit_basic/Dot Product is not supported", "inner product is now supported")
})

restResources {
restApi {
include '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'bulk', 'update',
'scripts_painless_execute', 'put_script', 'delete_script'
'scripts_painless_execute', 'put_script', 'delete_script', 'capabilities'
}
}

Expand Down
Loading

0 comments on commit 203f130

Please sign in to comment.