-
Notifications
You must be signed in to change notification settings - Fork 24.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for bitwise inner-product in painless #116082
Changes from 3 commits
5c2f974
a196c9b
2c33bd9
7ba48fa
d4b95be
f2d1660
c46da24
5519af8
808ed5e
5c23174
2a36775
14670f7
3732ae4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pr: 116082 | ||
summary: Add support for bitwise inner-product in painless | ||
area: Vector Search | ||
type: enhancement | ||
issues: [] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -332,6 +332,9 @@ When using `bit` vectors, not all the vector functions are available. The suppor | |
* <<vector-functions-hamming,`hamming`>> – calculates Hamming distance, the sum of the bitwise XOR of the two vectors | ||
* <<vector-functions-l1,`l1norm`>> – calculates L^1^ distance, this is simply the `hamming` distance | ||
* <<vector-functions-l2,`l2norm`>> - calculates L^2^ distance, this is the square root of the `hamming` distance | ||
* <<vector-functions-dot-product,`dotProduct`>> – calculates dot product. When comparing two `bit` vectors, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. May be we can add that queryVector can be byte[] (of the same dims as docs or dims *8), or also can be a string, and can be of float[] There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ++ |
||
this is the sum of the bitwise AND of the two vectors. If providing `float[]` as a query vector, the `dotProduct` is | ||
the sum of the floating point values using the stored `bit` vector as a mask. | ||
|
||
Currently, the `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors. | ||
Currently, the `cosineSimilarity` function is not supported for `bit` vectors. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,13 +9,36 @@ | |
|
||
package org.elasticsearch.simdvec; | ||
|
||
import org.apache.lucene.util.BitUtil; | ||
import org.apache.lucene.util.Constants; | ||
import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport; | ||
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider; | ||
|
||
import java.lang.invoke.MethodHandle; | ||
import java.lang.invoke.MethodHandles; | ||
import java.lang.invoke.MethodType; | ||
|
||
import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY; | ||
|
||
public class ESVectorUtil { | ||
|
||
private static final MethodHandle BIT_COUNT_MH; | ||
static { | ||
try { | ||
// For xorBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time. | ||
// On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when | ||
// compared to Integer::bitCount. While Long::bitCount is optimal on x64. See | ||
// https://bugs.openjdk.org/browse/JDK-8336000 | ||
BIT_COUNT_MH = Constants.OS_ARCH.equals("aarch64") | ||
? MethodHandles.lookup() | ||
.findStatic(ESVectorUtil.class, "andBitCountInt", MethodType.methodType(int.class, byte[].class, byte[].class)) | ||
: MethodHandles.lookup() | ||
.findStatic(ESVectorUtil.class, "andBitCountLong", MethodType.methodType(int.class, byte[].class, byte[].class)); | ||
} catch (NoSuchMethodException | IllegalAccessException e) { | ||
throw new AssertionError(e); | ||
} | ||
} | ||
|
||
private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport(); | ||
|
||
public static long ipByteBinByte(byte[] q, byte[] d) { | ||
|
@@ -24,4 +47,103 @@ public static long ipByteBinByte(byte[] q, byte[] d) { | |
} | ||
return IMPL.ipByteBinByte(q, d); | ||
} | ||
|
||
/** | ||
* Compute the inner product of two vectors, where the query vector is a byte vector and the document vector is a bit vector. | ||
* This will return the sum of the query vector values using the document vector as a mask. | ||
* @param q the query vector | ||
* @param d the document vector | ||
* @return the inner product of the two vectors | ||
*/ | ||
public static int ipByteBit(byte[] q, byte[] d) { | ||
if (q.length != d.length * Byte.SIZE) { | ||
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length); | ||
} | ||
int result = 0; | ||
// now combine the two vectors, summing the byte dimensions where the bit in d is `1` | ||
for (int i = 0; i < d.length; i++) { | ||
byte mask = d[i]; | ||
for (int j = 0; j < Byte.SIZE; j++) { | ||
if ((mask & (1 << j)) != 0) { | ||
result += q[i * Byte.SIZE + j]; | ||
} | ||
} | ||
} | ||
return result; | ||
} | ||
|
||
/** | ||
* Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a bit vector. | ||
* This will return the sum of the query vector values using the document vector as a mask. | ||
* @param q the query vector | ||
* @param d the document vector | ||
* @return the inner product of the two vectors | ||
*/ | ||
public static float ipFloatBit(float[] q, byte[] d) { | ||
if (q.length != d.length * Byte.SIZE) { | ||
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length); | ||
} | ||
float result = 0; | ||
for (int i = 0; i < d.length; i++) { | ||
byte mask = d[i]; | ||
for (int j = 0; j < Byte.SIZE; j++) { | ||
if ((mask & (1 << j)) != 0) { | ||
result += q[i * Byte.SIZE + j]; | ||
} | ||
} | ||
} | ||
return result; | ||
} | ||
|
||
/** | ||
* AND bit count computed over signed bytes. | ||
* Copied from Lucene's XOR implementation | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is more so for my education. What's the thinking here for putting this in ES vs Lucene? given that we have XOR in Lucene VectorUtil and this seem complementary. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @john-wagster there is no compelling reason for keeping it out of Lucene. But, its weird for there to be public utility methods in Lucene when nothing directly utilizes it. |
||
* @param a bytes containing a vector | ||
* @param b bytes containing another vector, of the same dimension | ||
* @return the value of the AND bit count of the two vectors | ||
*/ | ||
public static int andBitCount(byte[] a, byte[] b) { | ||
if (a.length != b.length) { | ||
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); | ||
} | ||
try { | ||
return (int) BIT_COUNT_MH.invokeExact(a, b); | ||
} catch (Throwable e) { | ||
if (e instanceof Error err) { | ||
throw err; | ||
} else if (e instanceof RuntimeException re) { | ||
throw re; | ||
} else { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
} | ||
|
||
/** AND bit count striding over 4 bytes at a time. */ | ||
static int andBitCountInt(byte[] a, byte[] b) { | ||
int distance = 0, i = 0; | ||
// limit to number of int values in the array iterating by int byte views | ||
for (final int upperBound = a.length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { | ||
distance += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(a, i) & (int) BitUtil.VH_NATIVE_INT.get(b, i)); | ||
} | ||
// tail: | ||
for (; i < a.length; i++) { | ||
distance += Integer.bitCount((a[i] & b[i]) & 0xFF); | ||
} | ||
return distance; | ||
} | ||
|
||
/** AND bit count striding over 8 bytes at a time**/ | ||
static int andBitCountLong(byte[] a, byte[] b) { | ||
int distance = 0, i = 0; | ||
// limit to number of long values in the array iterating by long byte views | ||
for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) { | ||
benwtrent marked this conversation as resolved.
Show resolved
Hide resolved
|
||
distance += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(a, i) & (long) BitUtil.VH_NATIVE_LONG.get(b, i)); | ||
} | ||
// tail: | ||
for (; i < a.length; i++) { | ||
distance += Integer.bitCount((a[i] & b[i]) & 0xFF); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could the tail be done with a single There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Possibly? But I didn't want to bother with over optimizing. Especially since these methods are effectively copy-pastes of what exists in Lucene for xor (just changing to |
||
} | ||
return distance; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On line 19 we also say that
dot_product
is not supported for bit vectors.