Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First round of optimizations for vector functions. #46294

Merged
merged 8 commits into from
Sep 3, 2019
2 changes: 1 addition & 1 deletion docs/reference/mapping/types/dense-vector.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ PUT my_index/_doc/2

Internally, each document's dense vector is encoded as a binary
doc value. Its size in bytes is equal to
`4 * dims`, where `dims`—the number of the vector's dimensions.
`4 * dims + 4`, where `dims`—the number of the vector's dimensions.
2 changes: 1 addition & 1 deletion docs/reference/mapping/types/sparse-vector.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@ PUT my_index/_doc/2

Internally, each document's sparse vector is encoded as a binary
doc value. Its size in bytes is equal to
`6 * NUMBER_OF_DIMENSIONS`, where `NUMBER_OF_DIMENSIONS` -
`6 * NUMBER_OF_DIMENSIONS + 4`, where `NUMBER_OF_DIMENSIONS` -
number of the vector's dimensions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.script.ExplainableScoreScript;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.Script;
import org.elasticsearch.Version;

import java.io.IOException;
import java.util.Objects;
Expand Down Expand Up @@ -52,22 +53,24 @@ public float score() {

private final int shardId;
private final String indexName;

private final Version indexVersion;

public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script) {
super(CombineFunction.REPLACE);
this.sScript = sScript;
this.script = script;
this.indexName = null;
this.shardId = -1;
this.indexVersion = null;
}

public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId) {
public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId, Version indexVersion) {
super(CombineFunction.REPLACE);
this.sScript = sScript;
this.script = script;
this.indexName = indexName;
this.shardId = shardId;
this.indexVersion = indexVersion;
}

@Override
Expand All @@ -77,6 +80,7 @@ public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOEx
leafScript.setScorer(scorer);
leafScript._setIndexName(indexName);
leafScript._setShard(shardId);
leafScript._setIndexVersion(indexVersion);
return new LeafScoreFunction() {
@Override
public double score(int docId, float subQueryScore) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ protected ScoreFunction doToFunction(QueryShardContext context) {
try {
ScoreScript.Factory factory = context.getScriptService().compile(script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
return new ScriptScoreFunction(script, searchScript, context.index().getName(), context.getShardId());
return new ScriptScoreFunction(script, searchScript,
context.index().getName(), context.getShardId(), context.indexVersionCreated());
} catch (Exception e) {
throw new QueryShardException(context, "script_score: the script could not be loaded", e);
}
Expand Down
22 changes: 22 additions & 0 deletions server/src/main/java/org/elasticsearch/script/ScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.index.fielddata.ScriptDocValues;
import org.elasticsearch.search.lookup.LeafSearchLookup;
import org.elasticsearch.search.lookup.SearchLookup;
import org.elasticsearch.Version;

import java.io.IOException;
import java.io.UncheckedIOException;
Expand Down Expand Up @@ -56,6 +57,7 @@ public abstract class ScoreScript {
private int docId;
private int shardId = -1;
private String indexName = null;
private Version indexVersion = null;

public ScoreScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
// null check needed b/c of expression engine subclass
Expand Down Expand Up @@ -155,6 +157,19 @@ public String _getIndex() {
}
}

/**
* Starting a name with underscore, so that the user cannot access this function directly through a script
* It is only used within predefined painless functions.
* @return index version or throws an exception if the index version is not set up for this script instance
*/
public Version _getIndexVersion() {
if (indexVersion != null) {
return indexVersion;
} else {
throw new IllegalArgumentException("index version can not be looked up!");
}
}

/**
* Starting a name with underscore, so that the user cannot access this function directly through a script
*/
Expand All @@ -169,6 +184,13 @@ public void _setIndexName(String indexName) {
this.indexName = indexName;
}

/**
* Starting a name with underscore, so that the user cannot access this function directly through a script
*/
public void _setIndexVersion(Version indexVersion) {
this.indexVersion = indexVersion;
}


/** A factory to construct {@link ScoreScript} instances. */
public interface LeafFactory {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ setup:
index: test-index
id: 2
body:
my_dense_vector: [10.9, 10.9, 10.9]
my_dense_vector: [10.5, 10.9, 10.4]

- do:
indices.refresh: {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser.Token;
Expand All @@ -25,10 +26,11 @@
import org.elasticsearch.index.mapper.MapperParsingException;
import org.elasticsearch.index.mapper.ParseContext;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.xpack.vectors.query.VectorDVIndexFieldData;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.xpack.vectors.query.VectorDVIndexFieldData;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.time.ZoneId;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -180,8 +182,11 @@ public void parse(ParseContext context) throws IOException {

// encode array of floats as array of integers and store into buf
// this code is here and not int the VectorEncoderDecoder so not to create extra arrays
byte[] buf = new byte[dims * INT_BYTES];
int offset = 0;
byte[] bytes = indexCreatedVersion.onOrAfter(Version.V_7_5_0) ? new byte[dims * INT_BYTES + INT_BYTES] : new byte[dims * INT_BYTES];

ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
double dotProduct = 0f;

int dim = 0;
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
if (dim++ >= dims) {
Expand All @@ -190,18 +195,22 @@ public void parse(ParseContext context) throws IOException {
}
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()::getTokenLocation);
float value = context.parser().floatValue(true);
int intValue = Float.floatToIntBits(value);
buf[offset++] = (byte) (intValue >> 24);
buf[offset++] = (byte) (intValue >> 16);
buf[offset++] = (byte) (intValue >> 8);
buf[offset++] = (byte) intValue;

byteBuffer.putFloat(value);
dotProduct += value * value;
}
if (dim != dims) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] of doc [" +
context.sourceToParse().id() + "] has number of dimensions [" + dim +
"] less than defined in the mapping [" + dims +"]");
}
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(buf, 0, offset));

if (indexCreatedVersion.onOrAfter(Version.V_7_5_0)) {
// encode vector magnitude at the end
float vectorMagnitude = (float) Math.sqrt(dotProduct);
byteBuffer.putFloat(vectorMagnitude);
}
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(bytes));
if (context.doc().getByKey(fieldType().name()) != null) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() +
"] doesn't not support indexing multiple values for the same field in the same document");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ public void parse(ParseContext context) throws IOException {
}
}

BytesRef br = VectorEncoderDecoder.encodeSparseVector(dims, values, dimCount);
BytesRef br = VectorEncoderDecoder.encodeSparseVector(indexCreatedVersion, dims, values, dimCount);
BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), br);
context.doc().addWithKey(fieldType().name(), field);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.InPlaceMergeSorter;
import org.elasticsearch.Version;

import java.nio.ByteBuffer;

// static utility functions for encoding and decoding dense_vector and sparse_vector fields
public final class VectorEncoderDecoder {
Expand All @@ -19,81 +22,84 @@ private VectorEncoderDecoder() { }

/**
* Encodes a sparse array represented by values, dims and dimCount into a bytes array - BytesRef
* BytesRef: int[] floats encoded as integers values, 2 bytes for each dimension
* @param values - values of the sparse array
* BytesRef: int[] floats encoded as integers values, 2 bytes for each dimension, length of vector
* @param indexVersion - index version
* @param dims - dims of the sparse array
* @param values - values of the sparse array
* @param dimCount - number of the dimensions, necessary as values and dims are dynamically created arrays,
* and may be over-allocated
* @return BytesRef
*/
public static BytesRef encodeSparseVector(int[] dims, float[] values, int dimCount) {
public static BytesRef encodeSparseVector(Version indexVersion, int[] dims, float[] values, int dimCount) {
// 1. Sort dims and values
sortSparseDimsValues(dims, values, dimCount);
byte[] buf = new byte[dimCount * (INT_BYTES + SHORT_BYTES)];

// 2. Encode dimensions
// as each dimension is a positive value that doesn't exceed 65535, 2 bytes is enough for encoding it
int offset = 0;
byte[] bytes = indexVersion.onOrAfter(Version.V_7_5_0) ? new byte[dimCount * (INT_BYTES + SHORT_BYTES) + INT_BYTES] :
new byte[dimCount * (INT_BYTES + SHORT_BYTES)];
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);

for (int dim = 0; dim < dimCount; dim++) {
buf[offset] = (byte) (dims[dim] >> 8);
buf[offset+1] = (byte) dims[dim];
offset += SHORT_BYTES;
int dimValue = dims[dim];
byteBuffer.put((byte) (dimValue >> 8));
byteBuffer.put((byte) dimValue);
}

// 3. Encode values
double dotProduct = 0.0f;
for (int dim = 0; dim < dimCount; dim++) {
int intValue = Float.floatToIntBits(values[dim]);
buf[offset] = (byte) (intValue >> 24);
buf[offset+1] = (byte) (intValue >> 16);
buf[offset+2] = (byte) (intValue >> 8);
buf[offset+3] = (byte) intValue;
offset += INT_BYTES;
float value = values[dim];
byteBuffer.putFloat(value);
dotProduct += value * value;
}

return new BytesRef(buf);
// 4. Encode vector magnitude at the end
if (indexVersion.onOrAfter(Version.V_7_5_0)) {
float vectorMagnitude = (float) Math.sqrt(dotProduct);
byteBuffer.putFloat(vectorMagnitude);
}

return new BytesRef(bytes);
}

/**
* Decodes the first part of BytesRef into sparse vector dimensions
* @param indexVersion - index version
* @param vectorBR - sparse vector encoded in BytesRef
*/
public static int[] decodeSparseVectorDims(BytesRef vectorBR) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
int dimCount = vectorBR.length / (INT_BYTES + SHORT_BYTES);
public static int[] decodeSparseVectorDims(Version indexVersion, BytesRef vectorBR) {
int dimCount = indexVersion.onOrAfter(Version.V_7_5_0)
? (vectorBR.length - INT_BYTES) / (INT_BYTES + SHORT_BYTES)
: vectorBR.length / (INT_BYTES + SHORT_BYTES);
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, dimCount * SHORT_BYTES);

int[] dims = new int[dimCount];
int offset = vectorBR.offset;
for (int dim = 0; dim < dimCount; dim++) {
dims[dim] = ((vectorBR.bytes[offset] & 0xFF) << 8) | (vectorBR.bytes[offset+1] & 0xFF);
offset += SHORT_BYTES;
dims[dim] = ((byteBuffer.get() & 0xFF) << 8) | (byteBuffer.get() & 0xFF);
}
return dims;
}

/**
* Decodes the second part of the BytesRef into sparse vector values
* @param indexVersion - index version
* @param vectorBR - sparse vector encoded in BytesRef
*/
public static float[] decodeSparseVector(BytesRef vectorBR) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
int dimCount = vectorBR.length / (INT_BYTES + SHORT_BYTES);
int offset = vectorBR.offset + SHORT_BYTES * dimCount; //calculate the offset from where values are encoded
public static float[] decodeSparseVector(Version indexVersion, BytesRef vectorBR) {
int dimCount = indexVersion.onOrAfter(Version.V_7_5_0)
? (vectorBR.length - INT_BYTES) / (INT_BYTES + SHORT_BYTES)
: vectorBR.length / (INT_BYTES + SHORT_BYTES);
int offset = vectorBR.offset + SHORT_BYTES * dimCount;
float[] vector = new float[dimCount];

ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, offset, dimCount * INT_BYTES);
for (int dim = 0; dim < dimCount; dim++) {
int intValue = ((vectorBR.bytes[offset] & 0xFF) << 24) |
((vectorBR.bytes[offset+1] & 0xFF) << 16) |
((vectorBR.bytes[offset+2] & 0xFF) << 8) |
(vectorBR.bytes[offset+3] & 0xFF);
vector[dim] = Float.intBitsToFloat(intValue);
offset = offset + INT_BYTES;
vector[dim] = byteBuffer.getFloat();
}
return vector;
}


/**
* Sorts dimensions in the ascending order and
* sorts values in the same order as their corresponding dimensions
Expand Down Expand Up @@ -150,24 +156,20 @@ public void swap(int i, int j) {
}.sort(0, n);
}

public static int denseVectorLength(Version indexVersion, BytesRef vectorBR) {
return indexVersion.onOrAfter(Version.V_7_5_0)
? (vectorBR.length - INT_BYTES) / INT_BYTES
: vectorBR.length / INT_BYTES;
}

/**
* Decodes a BytesRef into an array of floats
* @param vectorBR - dense vector encoded in BytesRef
* Decodes the last 4 bytes of the encoded vector, which contains the vector magnitude.
* NOTE: this function can only be called on vectors from an index version greater than or
* equal to 7.4.0, since vectors created prior to that do not store the magnitude.
*/
public static float[] decodeDenseVector(BytesRef vectorBR) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
int dimCount = vectorBR.length / INT_BYTES;
float[] vector = new float[dimCount];
int offset = vectorBR.offset;
for (int dim = 0; dim < dimCount; dim++) {
int intValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) |
((vectorBR.bytes[offset++] & 0xFF) << 16) |
((vectorBR.bytes[offset++] & 0xFF) << 8) |
(vectorBR.bytes[offset++] & 0xFF);
vector[dim] = Float.intBitsToFloat(intValue);
}
return vector;
public static float decodeVectorMagnitude(Version indexVersion, BytesRef vectorBR) {
assert indexVersion.onOrAfter(Version.V_7_5_0);
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4);
}
}
Loading