Skip to content

Commit

Permalink
First round of optimizations for vector functions. (#46294)
Browse files Browse the repository at this point in the history
This PR merges the `vectors-optimize-brute-force` feature branch, which makes
the following changes to how vector functions are computed:
* Precompute the L2 norm of each vector at indexing time. (#45390)
* Switch to ByteBuffer for vector encoding. (#45936)
* Decode vectors and while computing the vector function. (#46103) 
* Use an array instead of a List for the query vector. (#46155)
* Precompute the normalized query vector when using cosine similarity. (#46190)

Co-authored-by: Mayya Sharipova <[email protected]>
  • Loading branch information
jtibshirani and mayya-sharipova committed Sep 4, 2019
1 parent d5ad86d commit 8e588db
Show file tree
Hide file tree
Showing 15 changed files with 560 additions and 299 deletions.
2 changes: 1 addition & 1 deletion docs/reference/mapping/types/dense-vector.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ PUT my_index/_doc/2

Internally, each document's dense vector is encoded as a binary
doc value. Its size in bytes is equal to
`4 * dims`, where `dims`—the number of the vector's dimensions.
`4 * dims + 4`, where `dims`—the number of the vector's dimensions.
2 changes: 1 addition & 1 deletion docs/reference/mapping/types/sparse-vector.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@ PUT my_index/_doc/2

Internally, each document's sparse vector is encoded as a binary
doc value. Its size in bytes is equal to
`6 * NUMBER_OF_DIMENSIONS`, where `NUMBER_OF_DIMENSIONS` -
`6 * NUMBER_OF_DIMENSIONS + 4`, where `NUMBER_OF_DIMENSIONS` -
number of the vector's dimensions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.script.ExplainableScoreScript;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.Script;
import org.elasticsearch.Version;

import java.io.IOException;
import java.util.Objects;
Expand Down Expand Up @@ -52,22 +53,24 @@ public float score() {

private final int shardId;
private final String indexName;

private final Version indexVersion;

public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script) {
super(CombineFunction.REPLACE);
this.sScript = sScript;
this.script = script;
this.indexName = null;
this.shardId = -1;
this.indexVersion = null;
}

public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId) {
public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId, Version indexVersion) {
super(CombineFunction.REPLACE);
this.sScript = sScript;
this.script = script;
this.indexName = indexName;
this.shardId = shardId;
this.indexVersion = indexVersion;
}

@Override
Expand All @@ -77,6 +80,7 @@ public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOEx
leafScript.setScorer(scorer);
leafScript._setIndexName(indexName);
leafScript._setShard(shardId);
leafScript._setIndexVersion(indexVersion);
return new LeafScoreFunction() {
@Override
public double score(int docId, float subQueryScore) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ protected ScoreFunction doToFunction(QueryShardContext context) {
try {
ScoreScript.Factory factory = context.getScriptService().compile(script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
return new ScriptScoreFunction(script, searchScript, context.index().getName(), context.getShardId());
return new ScriptScoreFunction(script, searchScript,
context.index().getName(), context.getShardId(), context.indexVersionCreated());
} catch (Exception e) {
throw new QueryShardException(context, "script_score: the script could not be loaded", e);
}
Expand Down
22 changes: 22 additions & 0 deletions server/src/main/java/org/elasticsearch/script/ScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.search.lookup.LeafSearchLookup;
import org.elasticsearch.search.lookup.SearchLookup;
import org.elasticsearch.Version;

import java.io.IOException;
import java.io.UncheckedIOException;
Expand Down Expand Up @@ -66,6 +67,7 @@ public abstract class ScoreScript {
private int docId;
private int shardId = -1;
private String indexName = null;
private Version indexVersion = null;

public ScoreScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
// null check needed b/c of expression engine subclass
Expand Down Expand Up @@ -165,6 +167,19 @@ public String _getIndex() {
}
}

/**
* Starting a name with underscore, so that the user cannot access this function directly through a script
* It is only used within predefined painless functions.
* @return index version or throws an exception if the index version is not set up for this script instance
*/
public Version _getIndexVersion() {
if (indexVersion != null) {
return indexVersion;
} else {
throw new IllegalArgumentException("index version can not be looked up!");
}
}

/**
* Starting a name with underscore, so that the user cannot access this function directly through a script
*/
Expand All @@ -179,6 +194,13 @@ public void _setIndexName(String indexName) {
this.indexName = indexName;
}

/**
* Starting a name with underscore, so that the user cannot access this function directly through a script
*/
public void _setIndexVersion(Version indexVersion) {
this.indexVersion = indexVersion;
}


/** A factory to construct {@link ScoreScript} instances. */
public interface LeafFactory {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ setup:
index: test-index
id: 2
body:
my_dense_vector: [10.9, 10.9, 10.9]
my_dense_vector: [10.5, 10.9, 10.4]

- do:
indices.refresh: {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser.Token;
Expand All @@ -25,10 +26,11 @@
import org.elasticsearch.index.mapper.MapperParsingException;
import org.elasticsearch.index.mapper.ParseContext;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.xpack.vectors.query.VectorDVIndexFieldData;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.xpack.vectors.query.VectorDVIndexFieldData;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.time.ZoneId;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -180,8 +182,11 @@ public void parse(ParseContext context) throws IOException {

// encode array of floats as array of integers and store into buf
// this code is here and not int the VectorEncoderDecoder so not to create extra arrays
byte[] buf = new byte[dims * INT_BYTES];
int offset = 0;
byte[] bytes = indexCreatedVersion.onOrAfter(Version.V_7_5_0) ? new byte[dims * INT_BYTES + INT_BYTES] : new byte[dims * INT_BYTES];

ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
double dotProduct = 0f;

int dim = 0;
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
if (dim++ >= dims) {
Expand All @@ -190,18 +195,22 @@ public void parse(ParseContext context) throws IOException {
}
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()::getTokenLocation);
float value = context.parser().floatValue(true);
int intValue = Float.floatToIntBits(value);
buf[offset++] = (byte) (intValue >> 24);
buf[offset++] = (byte) (intValue >> 16);
buf[offset++] = (byte) (intValue >> 8);
buf[offset++] = (byte) intValue;

byteBuffer.putFloat(value);
dotProduct += value * value;
}
if (dim != dims) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] of doc [" +
context.sourceToParse().id() + "] has number of dimensions [" + dim +
"] less than defined in the mapping [" + dims +"]");
}
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(buf, 0, offset));

if (indexCreatedVersion.onOrAfter(Version.V_7_5_0)) {
// encode vector magnitude at the end
float vectorMagnitude = (float) Math.sqrt(dotProduct);
byteBuffer.putFloat(vectorMagnitude);
}
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(bytes));
if (context.doc().getByKey(fieldType().name()) != null) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() +
"] doesn't not support indexing multiple values for the same field in the same document");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ public void parse(ParseContext context) throws IOException {
}
}

BytesRef br = VectorEncoderDecoder.encodeSparseVector(dims, values, dimCount);
BytesRef br = VectorEncoderDecoder.encodeSparseVector(indexCreatedVersion, dims, values, dimCount);
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), br);
context.doc().addWithKey(fieldType().name(), field);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.InPlaceMergeSorter;
import org.elasticsearch.Version;

import java.nio.ByteBuffer;

// static utility functions for encoding and decoding dense_vector and sparse_vector fields
public final class VectorEncoderDecoder {
Expand All @@ -19,81 +22,84 @@ private VectorEncoderDecoder() { }

/**
* Encodes a sparse array represented by values, dims and dimCount into a bytes array - BytesRef
* BytesRef: int[] floats encoded as integers values, 2 bytes for each dimension
* @param values - values of the sparse array
* BytesRef: int[] floats encoded as integers values, 2 bytes for each dimension, length of vector
* @param indexVersion - index version
* @param dims - dims of the sparse array
* @param values - values of the sparse array
* @param dimCount - number of the dimensions, necessary as values and dims are dynamically created arrays,
* and may be over-allocated
* @return BytesRef
*/
public static BytesRef encodeSparseVector(int[] dims, float[] values, int dimCount) {
public static BytesRef encodeSparseVector(Version indexVersion, int[] dims, float[] values, int dimCount) {
// 1. Sort dims and values
sortSparseDimsValues(dims, values, dimCount);
byte[] buf = new byte[dimCount * (INT_BYTES + SHORT_BYTES)];

// 2. Encode dimensions
// as each dimension is a positive value that doesn't exceed 65535, 2 bytes is enough for encoding it
int offset = 0;
byte[] bytes = indexVersion.onOrAfter(Version.V_7_5_0) ? new byte[dimCount * (INT_BYTES + SHORT_BYTES) + INT_BYTES] :
new byte[dimCount * (INT_BYTES + SHORT_BYTES)];
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);

for (int dim = 0; dim < dimCount; dim++) {
buf[offset] = (byte) (dims[dim] >> 8);
buf[offset+1] = (byte) dims[dim];
offset += SHORT_BYTES;
int dimValue = dims[dim];
byteBuffer.put((byte) (dimValue >> 8));
byteBuffer.put((byte) dimValue);
}

// 3. Encode values
double dotProduct = 0.0f;
for (int dim = 0; dim < dimCount; dim++) {
int intValue = Float.floatToIntBits(values[dim]);
buf[offset] = (byte) (intValue >> 24);
buf[offset+1] = (byte) (intValue >> 16);
buf[offset+2] = (byte) (intValue >> 8);
buf[offset+3] = (byte) intValue;
offset += INT_BYTES;
float value = values[dim];
byteBuffer.putFloat(value);
dotProduct += value * value;
}

return new BytesRef(buf);
// 4. Encode vector magnitude at the end
if (indexVersion.onOrAfter(Version.V_7_5_0)) {
float vectorMagnitude = (float) Math.sqrt(dotProduct);
byteBuffer.putFloat(vectorMagnitude);
}

return new BytesRef(bytes);
}

/**
* Decodes the first part of BytesRef into sparse vector dimensions
* @param indexVersion - index version
* @param vectorBR - sparse vector encoded in BytesRef
*/
public static int[] decodeSparseVectorDims(BytesRef vectorBR) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
int dimCount = vectorBR.length / (INT_BYTES + SHORT_BYTES);
public static int[] decodeSparseVectorDims(Version indexVersion, BytesRef vectorBR) {
int dimCount = indexVersion.onOrAfter(Version.V_7_5_0)
? (vectorBR.length - INT_BYTES) / (INT_BYTES + SHORT_BYTES)
: vectorBR.length / (INT_BYTES + SHORT_BYTES);
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, dimCount * SHORT_BYTES);

int[] dims = new int[dimCount];
int offset = vectorBR.offset;
for (int dim = 0; dim < dimCount; dim++) {
dims[dim] = ((vectorBR.bytes[offset] & 0xFF) << 8) | (vectorBR.bytes[offset+1] & 0xFF);
offset += SHORT_BYTES;
dims[dim] = ((byteBuffer.get() & 0xFF) << 8) | (byteBuffer.get() & 0xFF);
}
return dims;
}

/**
* Decodes the second part of the BytesRef into sparse vector values
* @param indexVersion - index version
* @param vectorBR - sparse vector encoded in BytesRef
*/
public static float[] decodeSparseVector(BytesRef vectorBR) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
int dimCount = vectorBR.length / (INT_BYTES + SHORT_BYTES);
int offset = vectorBR.offset + SHORT_BYTES * dimCount; //calculate the offset from where values are encoded
public static float[] decodeSparseVector(Version indexVersion, BytesRef vectorBR) {
int dimCount = indexVersion.onOrAfter(Version.V_7_5_0)
? (vectorBR.length - INT_BYTES) / (INT_BYTES + SHORT_BYTES)
: vectorBR.length / (INT_BYTES + SHORT_BYTES);
int offset = vectorBR.offset + SHORT_BYTES * dimCount;
float[] vector = new float[dimCount];

ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, offset, dimCount * INT_BYTES);
for (int dim = 0; dim < dimCount; dim++) {
int intValue = ((vectorBR.bytes[offset] & 0xFF) << 24) |
((vectorBR.bytes[offset+1] & 0xFF) << 16) |
((vectorBR.bytes[offset+2] & 0xFF) << 8) |
(vectorBR.bytes[offset+3] & 0xFF);
vector[dim] = Float.intBitsToFloat(intValue);
offset = offset + INT_BYTES;
vector[dim] = byteBuffer.getFloat();
}
return vector;
}


/**
* Sorts dimensions in the ascending order and
* sorts values in the same order as their corresponding dimensions
Expand Down Expand Up @@ -150,24 +156,20 @@ public void swap(int i, int j) {
}.sort(0, n);
}

public static int denseVectorLength(Version indexVersion, BytesRef vectorBR) {
return indexVersion.onOrAfter(Version.V_7_5_0)
? (vectorBR.length - INT_BYTES) / INT_BYTES
: vectorBR.length / INT_BYTES;
}

/**
* Decodes a BytesRef into an array of floats
* @param vectorBR - dense vector encoded in BytesRef
* Decodes the last 4 bytes of the encoded vector, which contains the vector magnitude.
* NOTE: this function can only be called on vectors from an index version greater than or
* equal to 7.5.0, since vectors created prior to that do not store the magnitude.
*/
public static float[] decodeDenseVector(BytesRef vectorBR) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
int dimCount = vectorBR.length / INT_BYTES;
float[] vector = new float[dimCount];
int offset = vectorBR.offset;
for (int dim = 0; dim < dimCount; dim++) {
int intValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) |
((vectorBR.bytes[offset++] & 0xFF) << 16) |
((vectorBR.bytes[offset++] & 0xFF) << 8) |
(vectorBR.bytes[offset++] & 0xFF);
vector[dim] = Float.intBitsToFloat(intValue);
}
return vector;
public static float decodeVectorMagnitude(Version indexVersion, BytesRef vectorBR) {
assert indexVersion.onOrAfter(Version.V_7_5_0);
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4);
}
}
Loading

0 comments on commit 8e588db

Please sign in to comment.