Add support for bitwise inner-product in painless #116082

Merged
5 changes: 5 additions & 0 deletions docs/changelog/116082.yaml
@@ -0,0 +1,5 @@
pr: 116082
summary: Add support for bitwise inner-product in painless
area: Vector Search
type: enhancement
issues: []
5 changes: 4 additions & 1 deletion docs/reference/vectors/vector-functions.asciidoc
Contributor

On line 19 we also say that dot_product is not supported for bit vectors.

Expand Up @@ -332,6 +332,9 @@ When using `bit` vectors, not all the vector functions are available. The suppor
* <<vector-functions-hamming,`hamming`>> – calculates Hamming distance, the sum of the bitwise XOR of the two vectors
* <<vector-functions-l1,`l1norm`>> – calculates L^1^ distance, this is simply the `hamming` distance
* <<vector-functions-l2,`l2norm`>> - calculates L^2^ distance, this is the square root of the `hamming` distance
* <<vector-functions-dot-product,`dotProduct`>> – calculates dot product. When comparing two `bit` vectors,
Contributor

Maybe we can add that `queryVector` can be a `byte[]` (with the same dims as the docs, or dims * 8), a string, or a `float[]`.

Member Author

++

this is the sum of the bitwise AND of the two vectors. If providing `float[]` as a query vector, the `dotProduct` is
the sum of the floating point values using the stored `bit` vector as a mask.

Currently, the `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
Currently, the `cosineSimilarity` function is not supported for `bit` vectors.
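
The two cases described in the doc change above can be illustrated with a small, self-contained sketch. The class and method names here are hypothetical and are not part of the PR. The bit ordering (bit `j` of byte `i` maps to query element `i * 8 + j`) is an assumption taken from the `1 << j` loop in `ipFloatBit` as it appears in this diff.

```java
class BitDotProductSketch {

    // bit x bit: count the positions where both packed vectors have a 1 (popcount of the bitwise AND).
    static int bitDot(byte[] a, byte[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + a.length + " != " + b.length);
        }
        int sum = 0;
        for (int i = 0; i < a.length; i++) {
            sum += Integer.bitCount((a[i] & b[i]) & 0xFF);
        }
        return sum;
    }

    // float x bit: add up the query values whose corresponding document bit is set.
    static float maskedSum(float[] q, byte[] bits) {
        if (q.length != bits.length * Byte.SIZE) {
            throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + " != " + Byte.SIZE + " x " + bits.length);
        }
        float sum = 0f;
        for (int i = 0; i < bits.length; i++) {
            for (int j = 0; j < Byte.SIZE; j++) {
                if ((bits[i] & (1 << j)) != 0) {
                    sum += q[i * Byte.SIZE + j];
                }
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        byte[] doc = { (byte) 0b0000_0101 };   // bits 0 and 2 set
        byte[] query = { (byte) 0b0000_0110 }; // bits 1 and 2 set
        System.out.println(bitDot(query, doc)); // prints 1 (only bit 2 is shared)

        float[] floatQuery = { 0.5f, 1.0f, 2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f };
        System.out.println(maskedSum(floatQuery, doc)); // prints 2.5 (0.5 + 2.0)
    }
}
```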

122 changes: 122 additions & 0 deletions libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java
Expand Up @@ -9,13 +9,36 @@

package org.elasticsearch.simdvec;

import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Constants;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY;

public class ESVectorUtil {

    private static final MethodHandle BIT_COUNT_MH;
    static {
        try {
            // For andBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time.
            // On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code when
            // compared to Integer::bitCount, while Long::bitCount is optimal on x64. See
            // https://bugs.openjdk.org/browse/JDK-8336000
            BIT_COUNT_MH = Constants.OS_ARCH.equals("aarch64")
                ? MethodHandles.lookup()
                    .findStatic(ESVectorUtil.class, "andBitCountInt", MethodType.methodType(int.class, byte[].class, byte[].class))
                : MethodHandles.lookup()
                    .findStatic(ESVectorUtil.class, "andBitCountLong", MethodType.methodType(int.class, byte[].class, byte[].class));
        } catch (NoSuchMethodException | IllegalAccessException e) {
            throw new AssertionError(e);
        }
    }

private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport();

public static long ipByteBinByte(byte[] q, byte[] d) {
Expand All @@ -24,4 +47,103 @@ public static long ipByteBinByte(byte[] q, byte[] d) {
}
return IMPL.ipByteBinByte(q, d);
}

    /**
     * Compute the inner product of two vectors, where the query vector is a byte vector and the document vector is a bit vector.
     * This will return the sum of the query vector values using the document vector as a mask.
     * @param q the query vector
     * @param d the document vector
     * @return the inner product of the two vectors
     */
    public static int ipByteBit(byte[] q, byte[] d) {
        if (q.length != d.length * Byte.SIZE) {
            throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + " != " + Byte.SIZE + " x " + d.length);
        }
        int result = 0;
        // now combine the two vectors, summing the byte dimensions where the bit in d is `1`
        for (int i = 0; i < d.length; i++) {
            byte mask = d[i];
            for (int j = 0; j < Byte.SIZE; j++) {
                if ((mask & (1 << j)) != 0) {
                    result += q[i * Byte.SIZE + j];
                }
            }
        }
        return result;
    }

    /**
     * Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a bit vector.
     * This will return the sum of the query vector values using the document vector as a mask.
     * @param q the query vector
     * @param d the document vector
     * @return the inner product of the two vectors
     */
    public static float ipFloatBit(float[] q, byte[] d) {
        if (q.length != d.length * Byte.SIZE) {
            throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + " != " + Byte.SIZE + " x " + d.length);
        }
        float result = 0;
        // sum the float dimensions where the bit in d is `1`
        for (int i = 0; i < d.length; i++) {
            byte mask = d[i];
            for (int j = 0; j < Byte.SIZE; j++) {
                if ((mask & (1 << j)) != 0) {
                    result += q[i * Byte.SIZE + j];
                }
            }
        }
        return result;
    }

    /**
     * AND bit count computed over signed bytes.
     * Copied from Lucene's XOR implementation
Contributor

This is mostly for my own education: what's the thinking behind putting this in ES rather than Lucene, given that we have XOR in Lucene's `VectorUtil` and this seems complementary?

Member Author

@john-wagster there is no compelling reason for keeping it out of Lucene. But it's weird for there to be public utility methods in Lucene when nothing directly utilizes them.

     * @param a bytes containing a vector
     * @param b bytes containing another vector, of the same dimension
     * @return the value of the AND bit count of the two vectors
     */
    public static int andBitCount(byte[] a, byte[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
        }
        try {
            return (int) BIT_COUNT_MH.invokeExact(a, b);
        } catch (Throwable e) {
            if (e instanceof Error err) {
                throw err;
            } else if (e instanceof RuntimeException re) {
                throw re;
            } else {
                throw new RuntimeException(e);
            }
        }
    }

    /** AND bit count striding over 4 bytes at a time. */
    static int andBitCountInt(byte[] a, byte[] b) {
        int distance = 0, i = 0;
        // limit to number of int values in the array iterating by int byte views
        for (final int upperBound = a.length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
            distance += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(a, i) & (int) BitUtil.VH_NATIVE_INT.get(b, i));
        }
        // tail:
        for (; i < a.length; i++) {
            distance += Integer.bitCount((a[i] & b[i]) & 0xFF);
        }
        return distance;
    }

    /** AND bit count striding over 8 bytes at a time. */
    static int andBitCountLong(byte[] a, byte[] b) {
        int distance = 0, i = 0;
        // limit to number of long values in the array iterating by long byte views
        for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
            distance += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(a, i) & (long) BitUtil.VH_NATIVE_LONG.get(b, i));
        }
        // tail:
        for (; i < a.length; i++) {
            distance += Integer.bitCount((a[i] & b[i]) & 0xFF);
Member

Could the tail be done with a single Long.bitCount call, if using a mask based on the number of remaining bytes?

Member Author

Possibly? But I didn't want to bother with over-optimizing, especially since these methods are effectively copy-pastes of what exists in Lucene for XOR (just changing to `&`).
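
For reference, one way the masked-tail idea raised in this thread could look is sketched below. This is not what the PR does. It assumes a little-endian `VarHandle` view built with the JDK's `MethodHandles.byteArrayViewVarHandle` (standing in for Lucene's `BitUtil.VH_NATIVE_LONG`, which uses native order), so that the unprocessed tail bytes always land in the high-order bytes of the final, overlapping 8-byte read. The class and method names are hypothetical.

```java
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.nio.ByteOrder;
import java.util.Arrays;

class MaskedTailSketch {
    // Little-endian long view over a byte[]; plain get() is permitted at unaligned offsets.
    private static final VarHandle LONG_LE = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.LITTLE_ENDIAN);

    static int andBitCountMaskedTail(byte[] a, byte[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + a.length + " != " + b.length);
        }
        int count = 0, i = 0;
        // main loop: 8 bytes per iteration
        for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
            count += Long.bitCount((long) LONG_LE.get(a, i) & (long) LONG_LE.get(b, i));
        }
        int remaining = a.length - i;
        if (remaining == 0) {
            return count;
        }
        if (a.length < Long.BYTES) {
            // too short to re-read a full long, so fall back to the per-byte tail
            for (; i < a.length; i++) {
                count += Integer.bitCount((a[i] & b[i]) & 0xFF);
            }
            return count;
        }
        // Re-read the last 8 bytes (overlapping bytes that were already counted) and keep
        // only the `remaining` high-order bytes, which are the uncounted tail.
        int last = a.length - Long.BYTES;
        long mask = -1L << (Byte.SIZE * (Long.BYTES - remaining));
        long tail = (long) LONG_LE.get(a, last) & (long) LONG_LE.get(b, last);
        return count + Long.bitCount(tail & mask);
    }

    public static void main(String[] args) {
        byte[] a = new byte[31];
        byte[] b = new byte[31];
        Arrays.fill(a, (byte) 0xFF);
        Arrays.fill(b, (byte) 0x0F);
        System.out.println(andBitCountMaskedTail(a, b)); // prints 124 (31 bytes x 4 shared bits)
    }
}
```

The overlapping read trades the per-byte loop for one extra `Long.bitCount`, at the cost of a short-array fallback and a dependency on a fixed byte order, which is in line with the author's point about not over-optimizing here.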

        }
        return distance;
    }
}
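
The dispatch pattern used in this file (pick a 4-byte or 8-byte striding implementation once, at class-initialization time, and call it through a `MethodHandle`) can be tried outside Elasticsearch with JDK classes only. The sketch below is an approximation under stated assumptions: it reads `os.arch` from `System.getProperty` instead of Lucene's `Constants.OS_ARCH`, and builds its own `VarHandle` views instead of using `BitUtil.VH_NATIVE_INT` and `BitUtil.VH_NATIVE_LONG`. All names in it are hypothetical.

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.invoke.VarHandle;
import java.nio.ByteOrder;
import java.util.Arrays;

class AndBitCountDispatchSketch {
    private static final VarHandle VH_INT = MethodHandles.byteArrayViewVarHandle(int[].class, ByteOrder.nativeOrder());
    private static final VarHandle VH_LONG = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.nativeOrder());
    private static final MethodHandle BIT_COUNT;

    static {
        try {
            // Choose the stride width once, based on the CPU architecture.
            String name = "aarch64".equals(System.getProperty("os.arch")) ? "strideInt" : "strideLong";
            BIT_COUNT = MethodHandles.lookup()
                .findStatic(AndBitCountDispatchSketch.class, name, MethodType.methodType(int.class, byte[].class, byte[].class));
        } catch (ReflectiveOperationException e) {
            throw new AssertionError(e);
        }
    }

    static int andBitCount(byte[] a, byte[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vector dimensions differ: " + a.length + " != " + b.length);
        }
        try {
            return (int) BIT_COUNT.invokeExact(a, b);
        } catch (RuntimeException | Error e) {
            throw e;
        } catch (Throwable t) {
            throw new RuntimeException(t);
        }
    }

    /** Popcount of a AND b, reading 4 bytes at a time, then a per-byte tail. */
    static int strideInt(byte[] a, byte[] b) {
        int count = 0, i = 0;
        for (final int upperBound = a.length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
            count += Integer.bitCount((int) VH_INT.get(a, i) & (int) VH_INT.get(b, i));
        }
        for (; i < a.length; i++) {
            count += Integer.bitCount((a[i] & b[i]) & 0xFF);
        }
        return count;
    }

    /** Popcount of a AND b, reading 8 bytes at a time, then a per-byte tail. */
    static int strideLong(byte[] a, byte[] b) {
        int count = 0, i = 0;
        for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
            count += Long.bitCount((long) VH_LONG.get(a, i) & (long) VH_LONG.get(b, i));
        }
        for (; i < a.length; i++) {
            count += Integer.bitCount((a[i] & b[i]) & 0xFF);
        }
        return count;
    }

    public static void main(String[] args) {
        byte[] a = new byte[13];
        byte[] b = new byte[13];
        Arrays.fill(a, (byte) 0xFF);
        Arrays.fill(b, (byte) 0x0F);
        System.out.println(andBitCount(a, b)); // prints 52 (13 bytes x 4 shared bits)
    }
}
```

Because a population count of a bitwise AND does not depend on byte order, the native-order views give the same result on any platform. Only the stride width differs between the two implementations.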
Expand Up @@ -21,6 +21,10 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
static final ESVectorizationProvider defaultedProvider = BaseVectorizationTests.defaultProvider();
static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider();

public void testBitAndCount() {
testBasicBitAndImpl(ESVectorUtil::andBitCountLong);
}

public void testIpByteBinInvariants() {
int iterations = atLeast(10);
for (int i = 0; i < iterations; i++) {
Expand All @@ -41,6 +45,23 @@ interface IpByteBin {
long apply(byte[] q, byte[] d);
}

interface BitOps {
long apply(byte[] q, byte[] d);
}

void testBasicBitAndImpl(BitOps bitAnd) {
assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 0 }));
assertEquals(0, bitAnd.apply(new byte[] { 1 }, new byte[] { 0 }));
assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 1 }));
assertEquals(1, bitAnd.apply(new byte[] { 1 }, new byte[] { 1 }));
byte[] a = new byte[31];
byte[] b = new byte[31];
random().nextBytes(a);
random().nextBytes(b);
int expected = scalarBitAnd(a, b);
assertEquals(expected, bitAnd.apply(a, b));
}

void testBasicIpByteBinImpl(IpByteBin ipByteBinFunc) {
assertEquals(15L, ipByteBinFunc.apply(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 }));
assertEquals(30L, ipByteBinFunc.apply(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 }));
Expand Down Expand Up @@ -115,6 +136,14 @@ static int scalarIpByteBin(byte[] q, byte[] d) {
return res;
}

static int scalarBitAnd(byte[] a, byte[] b) {
int res = 0;
for (int i = 0; i < a.length; i++) {
res += Integer.bitCount((a[i] & b[i]) & 0xFF);
}
return res;
}

public static int popcount(byte[] a, int aOffset, byte[] b, int length) {
int res = 0;
for (int j = 0; j < length; j++) {
2 changes: 1 addition & 1 deletion modules/lang-painless/build.gradle
Expand Up @@ -53,7 +53,7 @@ tasks.named("dependencyLicenses").configure {
restResources {
restApi {
include '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'bulk', 'update',
'scripts_painless_execute', 'put_script', 'delete_script'
'scripts_painless_execute', 'put_script', 'delete_script', 'capabilities'
}
}

Expand Up @@ -102,36 +102,6 @@ setup:
- match: {hits.hits.2._id: "3"}
- close_to: {hits.hits.2._score: {value: 3.4641016, error: 0.01}}

---
"Dot Product is not supported":
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
body:
query:
script_score:
query: {match_all: {} }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: [0, 111, -13, 14, -124]
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: "006ff30e84"
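
The two requests in the removed test above pass the same five bytes in two forms, a signed byte array and a hex string. The sketch below checks that equivalence with the JDK's `HexFormat` (Java 17+), which is used here only as a stand-in, since Elasticsearch's own parsing of hex query vectors is not shown in this diff. It also shows why the tests added later in this file use 40-element query vectors: five packed bytes hold 40 bits. The class and method names are hypothetical.

```java
import java.util.Arrays;
import java.util.HexFormat;

class QueryVectorFormsSketch {

    // Decode a hex-encoded query vector into its packed bytes.
    static byte[] decode(String hex) {
        return HexFormat.of().parseHex(hex);
    }

    public static void main(String[] args) {
        byte[] fromArray = { 0, 111, -13, 14, -124 };
        byte[] fromHex = decode("006ff30e84");
        System.out.println(Arrays.equals(fromArray, fromHex)); // prints true
        System.out.println(fromHex.length * Byte.SIZE);        // prints 40
    }
}
```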

---
"Cosine Similarity is not supported":
- do:
Expand Down Expand Up @@ -388,3 +358,119 @@ setup:

- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}
---
"Dot product with float":
- requires:
capabilities:
- method: POST
path: /_search
capabilities: [ byte_float_bit_dot_product ]
test_runner_features: [capabilities, close_to]
reason: Capability required to run test
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: { match_all: { } }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: [0.23, 1.45, 3.67, 4.89, -0.56, 2.34, 3.21, 1.78, -2.45, 0.98, -0.12, 3.45, 4.56, 2.78, 1.23, 0.67, 3.89, 4.12, -2.34, 1.56, 0.78, 3.21, 4.12, 2.45, -1.67, 0.34, -3.45, 4.56, -2.78, 1.23, -0.67, 3.89, -4.34, 2.12, -1.56, 0.78, -3.21, 4.45, 2.12, 1.67]

- match: { hits.total: 3 }

- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}}

- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score: {value: 27.23, error: 0.01}}

- match: {hits.hits.2._id: "1"}
- close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}}

- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: { match_all: { } }
script:
source: "dotProduct(params.query_vector, 'indexed_vector')"
params:
query_vector: [0.23, 1.45, 3.67, 4.89, -0.56, 2.34, 3.21, 1.78, -2.45, 0.98, -0.12, 3.45, 4.56, 2.78, 1.23, 0.67, 3.89, 4.12, -2.34, 1.56, 0.78, 3.21, 4.12, 2.45, -1.67, 0.34, -3.45, 4.56, -2.78, 1.23, -0.67, 3.89, -4.34, 2.12, -1.56, 0.78, -3.21, 4.45, 2.12, 1.67]

- match: { hits.total: 3 }

- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}}

- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score: {value: 27.23, error: 0.01}}

- match: {hits.hits.2._id: "1"}
- close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}}
---
"Dot product with byte":
- requires:
capabilities:
- method: POST
path: /_search
capabilities: [ byte_float_bit_dot_product ]
test_runner_features: capabilities
reason: Capability required to run test
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: { match_all: { } }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: [12, -34, 56, -78, 90, 12, 34, -56, 78, -90, 23, -45, 67, -89, 12, 34, 56, 78, 90, -12, 34, -56, 78, -90, 23, -45, 67, -89, 12, -34, 56, -78, 90, -12, 34, -56, 78, 90, 23, -45]

- match: { hits.total: 3 }

- match: {hits.hits.0._id: "1"}
- match: {hits.hits.0._score: 248}

- match: {hits.hits.1._id: "2"}
- match: {hits.hits.1._score: 136}

- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 20}

- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: { match_all: { } }
script:
source: "dotProduct(params.query_vector, 'indexed_vector')"
params:
query_vector: [12, -34, 56, -78, 90, 12, 34, -56, 78, -90, 23, -45, 67, -89, 12, 34, 56, 78, 90, -12, 34, -56, 78, -90, 23, -45, 67, -89, 12, -34, 56, -78, 90, -12, 34, -56, 78, 90, 23, -45]

- match: { hits.total: 3 }

- match: {hits.hits.0._id: "1"}
- match: {hits.hits.0._score: 248}

- match: {hits.hits.1._id: "2"}
- match: {hits.hits.1._score: 136}

- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 20}
Expand Up @@ -22,9 +22,12 @@ private SearchCapabilities() {}
private static final String RANGE_REGEX_INTERVAL_QUERY_CAPABILITY = "range_regexp_interval_queries";
/** Support synthetic source with `bit` type in `dense_vector` field when `index` is set to `false`. */
private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source";
/** Support Byte and Float with Bit dot product. */
private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product";

public static final Set<String> CAPABILITIES = Set.of(
RANGE_REGEX_INTERVAL_QUERY_CAPABILITY,
BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY
BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY,
BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY
);
}