Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for bitwise inner-product in painless #116082

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/116082.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 116082
summary: Add support for bitwise inner-product in painless
area: Vector Search
type: enhancement
issues: []
5 changes: 4 additions & 1 deletion docs/reference/vectors/vector-functions.asciidoc
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On line 19 we also say that dot_product is not supported for bit vectors.

Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,9 @@ When using `bit` vectors, not all the vector functions are available. The suppor
* <<vector-functions-hamming,`hamming`>> – calculates Hamming distance, the sum of the bitwise XOR of the two vectors
* <<vector-functions-l1,`l1norm`>> – calculates L^1^ distance, this is simply the `hamming` distance
* <<vector-functions-l2,`l2norm`>> - calculates L^2^ distance, this is the square root of the `hamming` distance
* <<vector-functions-dot-product,`dotProduct`>> – calculates dot product. When comparing two `bit` vectors,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be we can add that queryVector can be byte[] (of the same dims as docs or dims *8), or also can be a string, and can be of float[]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++

this is the sum of the bitwise AND of the two vectors. If providing `float[]` as a query vector, the `dotProduct` is
the sum of the floating point values using the stored `bit` vector as a mask.

Currently, the `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
Currently, the `cosineSimilarity` function is not supported for `bit` vectors.

101 changes: 101 additions & 0 deletions libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,23 @@

package org.elasticsearch.simdvec;

import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Constants;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;

import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY;

public class ESVectorUtil {

/**
* For xorBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time.
* On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when
* compared to Integer::bitCount. While Long::bitCount is optimal on x64. See
* https://bugs.openjdk.org/browse/JDK-8336000
*/
static final boolean XOR_BIT_COUNT_STRIDE_AS_INT = Constants.OS_ARCH.equals("aarch64");

private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport();

public static long ipByteBinByte(byte[] q, byte[] d) {
Expand All @@ -24,4 +34,95 @@ public static long ipByteBinByte(byte[] q, byte[] d) {
}
return IMPL.ipByteBinByte(q, d);
}

/**
* Compute the inner product of two vectors, where the query vector is a byte vector and the document vector is a bit vector.
* This will return the sum of the query vector values using the document vector as a mask.
* @param q the query vector
* @param d the document vector
* @return the inner product of the two vectors
*/
public static int ipByteBit(byte[] q, byte[] d) {
if (q.length != d.length * Byte.SIZE) {
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
}
int result = 0;
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
for (int i = 0; i < d.length; i++) {
byte mask = d[i];
for (int j = 0; j < Byte.SIZE; j++) {
if ((mask & (1 << j)) != 0) {
result += q[i * Byte.SIZE + j];
}
}
}
return result;
}

/**
* Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a bit vector.
* This will return the sum of the query vector values using the document vector as a mask.
* @param q the query vector
* @param d the document vector
* @return the inner product of the two vectors
*/
public static float ipFloatBit(float[] q, byte[] d) {
if (q.length != d.length * Byte.SIZE) {
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
}
float result = 0;
for (int i = 0; i < d.length; i++) {
byte mask = d[i];
for (int j = 0; j < Byte.SIZE; j++) {
if ((mask & (1 << j)) != 0) {
result += q[i * Byte.SIZE + j];
}
}
}
return result;
}

/**
* AND bit count computed over signed bytes.
*
* @param a bytes containing a vector
* @param b bytes containing another vector, of the same dimension
* @return the value of the AND bit count of the two vectors
*/
public static int andBitCount(byte[] a, byte[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
}
if (XOR_BIT_COUNT_STRIDE_AS_INT) {
benwtrent marked this conversation as resolved.
Show resolved Hide resolved
return andBitCountInt(a, b);
} else {
return andBitCountLong(a, b);
}
}

/** AND bit count striding over 4 bytes at a time. */
static int andBitCountInt(byte[] a, byte[] b) {
int distance = 0, i = 0;
for (final int upperBound = a.length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
distance += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(a, i) & (int) BitUtil.VH_NATIVE_INT.get(b, i));
}
// tail:
for (; i < a.length; i++) {
distance += Integer.bitCount((a[i] & b[i]) & 0xFF);
}
return distance;
}

/** AND bit count striding over 8 bytes at a time. */
static int andBitCountLong(byte[] a, byte[] b) {
int distance = 0, i = 0;
for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
benwtrent marked this conversation as resolved.
Show resolved Hide resolved
distance += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(a, i) & (long) BitUtil.VH_NATIVE_LONG.get(b, i));
}
// tail:
for (; i < a.length; i++) {
distance += Integer.bitCount((a[i] & b[i]) & 0xFF);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could the tail be done with a single Long.bitCount call, if using a mask based on the number of remaining bytes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly? But I didn't want to bother with over optimizing. Especially since these methods are effectively copy-pastes of what exists in Lucene for xor (just changing to &).

}
return distance;
}
}
2 changes: 1 addition & 1 deletion modules/lang-painless/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ tasks.named("dependencyLicenses").configure {
restResources {
restApi {
include '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'bulk', 'update',
'scripts_painless_execute', 'put_script', 'delete_script'
'scripts_painless_execute', 'put_script', 'delete_script', 'capabilities'
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,36 +102,6 @@ setup:
- match: {hits.hits.2._id: "3"}
- close_to: {hits.hits.2._score: {value: 3.4641016, error: 0.01}}

---
"Dot Product is not supported":
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
body:
query:
script_score:
query: {match_all: {} }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: [0, 111, -13, 14, -124]
- do:
catch: bad_request
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: "006ff30e84"

---
"Cosine Similarity is not supported":
- do:
Expand Down Expand Up @@ -388,3 +358,119 @@ setup:

- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 11.0}
---
"Dot product with float":
- requires:
capabilities:
- method: POST
path: /_search
capabilities: [ byte_float_bit_dot_product ]
test_runner_features: [capabilities, close_to]
reason: Capability required to run test
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: { match_all: { } }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: [0.23, 1.45, 3.67, 4.89, -0.56, 2.34, 3.21, 1.78, -2.45, 0.98, -0.12, 3.45, 4.56, 2.78, 1.23, 0.67, 3.89, 4.12, -2.34, 1.56, 0.78, 3.21, 4.12, 2.45, -1.67, 0.34, -3.45, 4.56, -2.78, 1.23, -0.67, 3.89, -4.34, 2.12, -1.56, 0.78, -3.21, 4.45, 2.12, 1.67]

- match: { hits.total: 3 }

- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}}

- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score:{value: 27.23, error: 0.01}}

- match: {hits.hits.2._id: "1"}
- close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}}

- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: { match_all: { } }
script:
source: "dotProduct(params.query_vector, 'indexed_vector')"
params:
query_vector: [0.23, 1.45, 3.67, 4.89, -0.56, 2.34, 3.21, 1.78, -2.45, 0.98, -0.12, 3.45, 4.56, 2.78, 1.23, 0.67, 3.89, 4.12, -2.34, 1.56, 0.78, 3.21, 4.12, 2.45, -1.67, 0.34, -3.45, 4.56, -2.78, 1.23, -0.67, 3.89, -4.34, 2.12, -1.56, 0.78, -3.21, 4.45, 2.12, 1.67]

- match: { hits.total: 3 }

- match: {hits.hits.0._id: "2"}
- close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}}

- match: {hits.hits.1._id: "3"}
- close_to: {hits.hits.1._score:{value: 27.23, error: 0.01}}

- match: {hits.hits.2._id: "1"}
- close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}}
---
"Dot product with byte":
- requires:
capabilities:
- method: POST
path: /_search
capabilities: [ byte_float_bit_dot_product ]
test_runner_features: capabilities
reason: Capability required to run test
- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: { match_all: { } }
script:
source: "dotProduct(params.query_vector, 'vector')"
params:
query_vector: [12, -34, 56, -78, 90, 12, 34, -56, 78, -90, 23, -45, 67, -89, 12, 34, 56, 78, 90, -12, 34, -56, 78, -90, 23, -45, 67, -89, 12, -34, 56, -78, 90, -12, 34, -56, 78, 90, 23, -45]

- match: { hits.total: 3 }

- match: {hits.hits.0._id: "1"}
- match: {hits.hits.0._score: 248}

- match: {hits.hits.1._id: "2"}
- match: {hits.hits.1._score: 136}

- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 20}

- do:
headers:
Content-Type: application/json
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: { match_all: { } }
script:
source: "dotProduct(params.query_vector, 'indexed_vector')"
params:
query_vector: [12, -34, 56, -78, 90, 12, 34, -56, 78, -90, 23, -45, 67, -89, 12, 34, 56, 78, 90, -12, 34, -56, 78, -90, 23, -45, 67, -89, 12, -34, 56, -78, 90, -12, 34, -56, 78, 90, 23, -45]

- match: { hits.total: 3 }

- match: {hits.hits.0._id: "1"}
- match: {hits.hits.0._score: 248}

- match: {hits.hits.1._id: "2"}
- match: {hits.hits.1._score: 136}

- match: {hits.hits.2._id: "3"}
- match: {hits.hits.2._score: 20}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ private SearchCapabilities() {}
private static final String RANGE_REGEX_INTERVAL_QUERY_CAPABILITY = "range_regexp_interval_queries";
/** Support synthetic source with `bit` type in `dense_vector` field when `index` is set to `false`. */
private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source";
/** Support Byte and Float with Bit dot product. */
private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product";

public static final Set<String> CAPABILITIES = Set.of(
RANGE_REGEX_INTERVAL_QUERY_CAPABILITY,
BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY
BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY,
BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY
);
}
Loading