From 8e588dbf6d08d02dae687ea7addf231bb6bfff74 Mon Sep 17 00:00:00 2001 From: Julie Tibshirani Date: Tue, 3 Sep 2019 16:30:12 -0700 Subject: [PATCH] First round of optimizations for vector functions. (#46294) This PR merges the `vectors-optimize-brute-force` feature branch, which makes the following changes to how vector functions are computed: * Precompute the L2 norm of each vector at indexing time. (#45390) * Switch to ByteBuffer for vector encoding. (#45936) * Decode vectors while computing the vector function. (#46103) * Use an array instead of a List for the query vector. (#46155) * Precompute the normalized query vector when using cosine similarity. (#46190) Co-authored-by: Mayya Sharipova --- .../mapping/types/dense-vector.asciidoc | 2 +- .../mapping/types/sparse-vector.asciidoc | 2 +- .../search/function/ScriptScoreFunction.java | 8 +- .../ScriptScoreFunctionBuilder.java | 3 +- .../org/elasticsearch/script/ScoreScript.java | 22 ++ .../vectors/20_dense_vector_special_cases.yml | 2 +- .../mapper/DenseVectorFieldMapper.java | 27 +- .../mapper/SparseVectorFieldMapper.java | 2 +- .../vectors/mapper/VectorEncoderDecoder.java | 108 +++--- .../xpack/vectors/query/ScoreScriptUtils.java | 341 ++++++++++-------- .../xpack/vectors/query/whitelist.txt | 18 +- .../mapper/DenseVectorFieldMapperTests.java | 67 +++- .../mapper/SparseVectorFieldMapperTests.java | 81 ++++- .../mapper/VectorEncoderDecoderTests.java | 75 ++-- .../vectors/query/ScoreScriptUtilsTests.java | 101 ++++-- 15 files changed, 560 insertions(+), 299 deletions(-) diff --git a/docs/reference/mapping/types/dense-vector.asciidoc b/docs/reference/mapping/types/dense-vector.asciidoc index 9462fe544af9d..a1799fae71885 100644 --- a/docs/reference/mapping/types/dense-vector.asciidoc +++ b/docs/reference/mapping/types/dense-vector.asciidoc @@ -54,4 +54,4 @@ PUT my_index/_doc/2 Internally, each document's dense vector is encoded as a binary doc value. 
Its size in bytes is equal to -`4 * dims`, where `dims`—the number of the vector's dimensions. \ No newline at end of file +`4 * dims + 4`, where `dims`—the number of the vector's dimensions. \ No newline at end of file diff --git a/docs/reference/mapping/types/sparse-vector.asciidoc b/docs/reference/mapping/types/sparse-vector.asciidoc index 7b437031513b7..af3b6a510377d 100644 --- a/docs/reference/mapping/types/sparse-vector.asciidoc +++ b/docs/reference/mapping/types/sparse-vector.asciidoc @@ -56,5 +56,5 @@ PUT my_index/_doc/2 Internally, each document's sparse vector is encoded as a binary doc value. Its size in bytes is equal to -`6 * NUMBER_OF_DIMENSIONS`, where `NUMBER_OF_DIMENSIONS` - +`6 * NUMBER_OF_DIMENSIONS + 4`, where `NUMBER_OF_DIMENSIONS` - number of the vector's dimensions. \ No newline at end of file diff --git a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java index 960df44a62514..65dacd51e13c5 100644 --- a/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java +++ b/server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java @@ -25,6 +25,7 @@ import org.elasticsearch.script.ExplainableScoreScript; import org.elasticsearch.script.ScoreScript; import org.elasticsearch.script.Script; +import org.elasticsearch.Version; import java.io.IOException; import java.util.Objects; @@ -52,7 +53,7 @@ public float score() { private final int shardId; private final String indexName; - + private final Version indexVersion; public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script) { super(CombineFunction.REPLACE); @@ -60,14 +61,16 @@ public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script) { this.script = script; this.indexName = null; this.shardId = -1; + this.indexVersion = null; } - public ScriptScoreFunction(Script 
sScript, ScoreScript.LeafFactory script, String indexName, int shardId) { + public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId, Version indexVersion) { super(CombineFunction.REPLACE); this.sScript = sScript; this.script = script; this.indexName = indexName; this.shardId = shardId; + this.indexVersion = indexVersion; } @Override @@ -77,6 +80,7 @@ public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOEx leafScript.setScorer(scorer); leafScript._setIndexName(indexName); leafScript._setShard(shardId); + leafScript._setIndexVersion(indexVersion); return new LeafScoreFunction() { @Override public double score(int docId, float subQueryScore) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java b/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java index accfd2f656999..8fc2d4ff6b1a4 100644 --- a/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java @@ -94,7 +94,8 @@ protected ScoreFunction doToFunction(QueryShardContext context) { try { ScoreScript.Factory factory = context.getScriptService().compile(script, ScoreScript.CONTEXT); ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup()); - return new ScriptScoreFunction(script, searchScript, context.index().getName(), context.getShardId()); + return new ScriptScoreFunction(script, searchScript, + context.index().getName(), context.getShardId(), context.indexVersionCreated()); } catch (Exception e) { throw new QueryShardException(context, "script_score: the script could not be loaded", e); } diff --git a/server/src/main/java/org/elasticsearch/script/ScoreScript.java b/server/src/main/java/org/elasticsearch/script/ScoreScript.java index 
f31af4c008c74..5c58e761b66ff 100644 --- a/server/src/main/java/org/elasticsearch/script/ScoreScript.java +++ b/server/src/main/java/org/elasticsearch/script/ScoreScript.java @@ -23,6 +23,7 @@ import org.elasticsearch.index.fielddata.ScriptDocValues; import org.elasticsearch.search.lookup.LeafSearchLookup; import org.elasticsearch.search.lookup.SearchLookup; +import org.elasticsearch.Version; import java.io.IOException; import java.io.UncheckedIOException; @@ -66,6 +67,7 @@ public abstract class ScoreScript { private int docId; private int shardId = -1; private String indexName = null; + private Version indexVersion = null; public ScoreScript(Map params, SearchLookup lookup, LeafReaderContext leafContext) { // null check needed b/c of expression engine subclass @@ -165,6 +167,19 @@ public String _getIndex() { } } + /** + * Starting a name with underscore, so that the user cannot access this function directly through a script + * It is only used within predefined painless functions. + * @return index version or throws an exception if the index version is not set up for this script instance + */ + public Version _getIndexVersion() { + if (indexVersion != null) { + return indexVersion; + } else { + throw new IllegalArgumentException("index version can not be looked up!"); + } + } + /** * Starting a name with underscore, so that the user cannot access this function directly through a script */ @@ -179,6 +194,13 @@ public void _setIndexName(String indexName) { this.indexName = indexName; } + /** + * Starting a name with underscore, so that the user cannot access this function directly through a script + */ + public void _setIndexVersion(Version indexVersion) { + this.indexVersion = indexVersion; + } + /** A factory to construct {@link ScoreScript} instances. 
*/ public interface LeafFactory { diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml index a5cc322426fa0..4d9394dc2b767 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/20_dense_vector_special_cases.yml @@ -46,7 +46,7 @@ setup: index: test-index id: 2 body: - my_dense_vector: [10.9, 10.9, 10.9] + my_dense_vector: [10.5, 10.9, 10.4] - do: indices.refresh: {} diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java index a7773e3e3c527..b1518d3ecd586 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java @@ -13,6 +13,7 @@ import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.Version; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser.Token; @@ -25,10 +26,11 @@ import org.elasticsearch.index.mapper.MapperParsingException; import org.elasticsearch.index.mapper.ParseContext; import org.elasticsearch.index.query.QueryShardContext; -import org.elasticsearch.xpack.vectors.query.VectorDVIndexFieldData; import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.xpack.vectors.query.VectorDVIndexFieldData; import java.io.IOException; +import java.nio.ByteBuffer; import java.time.ZoneId; import java.util.List; import java.util.Map; @@ -180,8 
+182,11 @@ public void parse(ParseContext context) throws IOException { // encode array of floats as array of integers and store into buf // this code is here and not int the VectorEncoderDecoder so not to create extra arrays - byte[] buf = new byte[dims * INT_BYTES]; - int offset = 0; + byte[] bytes = indexCreatedVersion.onOrAfter(Version.V_7_5_0) ? new byte[dims * INT_BYTES + INT_BYTES] : new byte[dims * INT_BYTES]; + + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + double dotProduct = 0f; + int dim = 0; for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) { if (dim++ >= dims) { @@ -190,18 +195,22 @@ public void parse(ParseContext context) throws IOException { } ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser()::getTokenLocation); float value = context.parser().floatValue(true); - int intValue = Float.floatToIntBits(value); - buf[offset++] = (byte) (intValue >> 24); - buf[offset++] = (byte) (intValue >> 16); - buf[offset++] = (byte) (intValue >> 8); - buf[offset++] = (byte) intValue; + + byteBuffer.putFloat(value); + dotProduct += value * value; } if (dim != dims) { throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] of doc [" + context.sourceToParse().id() + "] has number of dimensions [" + dim + "] less than defined in the mapping [" + dims +"]"); } - BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(buf, 0, offset)); + + if (indexCreatedVersion.onOrAfter(Version.V_7_5_0)) { + // encode vector magnitude at the end + float vectorMagnitude = (float) Math.sqrt(dotProduct); + byteBuffer.putFloat(vectorMagnitude); + } + BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), new BytesRef(bytes)); if (context.doc().getByKey(fieldType().name()) != null) { throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] doesn't not support indexing multiple values for the same 
field in the same document"); diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapper.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapper.java index 3c551a4ee525f..38ea21922f44f 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapper.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapper.java @@ -181,7 +181,7 @@ public void parse(ParseContext context) throws IOException { } } - BytesRef br = VectorEncoderDecoder.encodeSparseVector(dims, values, dimCount); + BytesRef br = VectorEncoderDecoder.encodeSparseVector(indexCreatedVersion, dims, values, dimCount); BinaryDocValuesField field = new BinaryDocValuesField(fieldType().name(), br); context.doc().addWithKey(fieldType().name(), field); } diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java index 67078b370ea98..2d591aaccd48f 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java @@ -9,6 +9,9 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.InPlaceMergeSorter; +import org.elasticsearch.Version; + +import java.nio.ByteBuffer; // static utility functions for encoding and decoding dense_vector and sparse_vector fields public final class VectorEncoderDecoder { @@ -19,81 +22,84 @@ private VectorEncoderDecoder() { } /** * Encodes a sparse array represented by values, dims and dimCount into a bytes array - BytesRef - * BytesRef: int[] floats encoded as integers values, 2 bytes for each dimension - * @param values - values of the sparse array + * BytesRef: int[] 
floats encoded as integers values, 2 bytes for each dimension, length of vector + * @param indexVersion - index version * @param dims - dims of the sparse array + * @param values - values of the sparse array * @param dimCount - number of the dimensions, necessary as values and dims are dynamically created arrays, * and may be over-allocated * @return BytesRef */ - public static BytesRef encodeSparseVector(int[] dims, float[] values, int dimCount) { + public static BytesRef encodeSparseVector(Version indexVersion, int[] dims, float[] values, int dimCount) { // 1. Sort dims and values sortSparseDimsValues(dims, values, dimCount); - byte[] buf = new byte[dimCount * (INT_BYTES + SHORT_BYTES)]; // 2. Encode dimensions // as each dimension is a positive value that doesn't exceed 65535, 2 bytes is enough for encoding it - int offset = 0; + byte[] bytes = indexVersion.onOrAfter(Version.V_7_5_0) ? new byte[dimCount * (INT_BYTES + SHORT_BYTES) + INT_BYTES] : + new byte[dimCount * (INT_BYTES + SHORT_BYTES)]; + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + for (int dim = 0; dim < dimCount; dim++) { - buf[offset] = (byte) (dims[dim] >> 8); - buf[offset+1] = (byte) dims[dim]; - offset += SHORT_BYTES; + int dimValue = dims[dim]; + byteBuffer.put((byte) (dimValue >> 8)); + byteBuffer.put((byte) dimValue); } // 3. Encode values + double dotProduct = 0.0f; for (int dim = 0; dim < dimCount; dim++) { - int intValue = Float.floatToIntBits(values[dim]); - buf[offset] = (byte) (intValue >> 24); - buf[offset+1] = (byte) (intValue >> 16); - buf[offset+2] = (byte) (intValue >> 8); - buf[offset+3] = (byte) intValue; - offset += INT_BYTES; + float value = values[dim]; + byteBuffer.putFloat(value); + dotProduct += value * value; } - return new BytesRef(buf); + // 4. 
Encode vector magnitude at the end + if (indexVersion.onOrAfter(Version.V_7_5_0)) { + float vectorMagnitude = (float) Math.sqrt(dotProduct); + byteBuffer.putFloat(vectorMagnitude); + } + + return new BytesRef(bytes); } /** * Decodes the first part of BytesRef into sparse vector dimensions + * @param indexVersion - index version * @param vectorBR - sparse vector encoded in BytesRef */ - public static int[] decodeSparseVectorDims(BytesRef vectorBR) { - if (vectorBR == null) { - throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); - } - int dimCount = vectorBR.length / (INT_BYTES + SHORT_BYTES); + public static int[] decodeSparseVectorDims(Version indexVersion, BytesRef vectorBR) { + int dimCount = indexVersion.onOrAfter(Version.V_7_5_0) + ? (vectorBR.length - INT_BYTES) / (INT_BYTES + SHORT_BYTES) + : vectorBR.length / (INT_BYTES + SHORT_BYTES); + ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, dimCount * SHORT_BYTES); + int[] dims = new int[dimCount]; - int offset = vectorBR.offset; for (int dim = 0; dim < dimCount; dim++) { - dims[dim] = ((vectorBR.bytes[offset] & 0xFF) << 8) | (vectorBR.bytes[offset+1] & 0xFF); - offset += SHORT_BYTES; + dims[dim] = ((byteBuffer.get() & 0xFF) << 8) | (byteBuffer.get() & 0xFF); } return dims; } /** * Decodes the second part of the BytesRef into sparse vector values + * @param indexVersion - index version * @param vectorBR - sparse vector encoded in BytesRef */ - public static float[] decodeSparseVector(BytesRef vectorBR) { - if (vectorBR == null) { - throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); - } - int dimCount = vectorBR.length / (INT_BYTES + SHORT_BYTES); - int offset = vectorBR.offset + SHORT_BYTES * dimCount; //calculate the offset from where values are encoded + public static float[] decodeSparseVector(Version indexVersion, BytesRef vectorBR) { + int dimCount = indexVersion.onOrAfter(Version.V_7_5_0) + ? 
(vectorBR.length - INT_BYTES) / (INT_BYTES + SHORT_BYTES) + : vectorBR.length / (INT_BYTES + SHORT_BYTES); + int offset = vectorBR.offset + SHORT_BYTES * dimCount; float[] vector = new float[dimCount]; + + ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, offset, dimCount * INT_BYTES); for (int dim = 0; dim < dimCount; dim++) { - int intValue = ((vectorBR.bytes[offset] & 0xFF) << 24) | - ((vectorBR.bytes[offset+1] & 0xFF) << 16) | - ((vectorBR.bytes[offset+2] & 0xFF) << 8) | - (vectorBR.bytes[offset+3] & 0xFF); - vector[dim] = Float.intBitsToFloat(intValue); - offset = offset + INT_BYTES; + vector[dim] = byteBuffer.getFloat(); } return vector; } - /** * Sorts dimensions in the ascending order and * sorts values in the same order as their corresponding dimensions @@ -150,24 +156,20 @@ public void swap(int i, int j) { }.sort(0, n); } + public static int denseVectorLength(Version indexVersion, BytesRef vectorBR) { + return indexVersion.onOrAfter(Version.V_7_5_0) + ? (vectorBR.length - INT_BYTES) / INT_BYTES + : vectorBR.length / INT_BYTES; + } + /** - * Decodes a BytesRef into an array of floats - * @param vectorBR - dense vector encoded in BytesRef + * Decodes the last 4 bytes of the encoded vector, which contains the vector magnitude. + * NOTE: this function can only be called on vectors from an index version greater than or + * equal to 7.5.0, since vectors created prior to that do not store the magnitude. 
*/ - public static float[] decodeDenseVector(BytesRef vectorBR) { - if (vectorBR == null) { - throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); - } - int dimCount = vectorBR.length / INT_BYTES; - float[] vector = new float[dimCount]; - int offset = vectorBR.offset; - for (int dim = 0; dim < dimCount; dim++) { - int intValue = ((vectorBR.bytes[offset++] & 0xFF) << 24) | - ((vectorBR.bytes[offset++] & 0xFF) << 16) | - ((vectorBR.bytes[offset++] & 0xFF) << 8) | - (vectorBR.bytes[offset++] & 0xFF); - vector[dim] = Float.intBitsToFloat(intValue); - } - return vector; + public static float decodeVectorMagnitude(Version indexVersion, BytesRef vectorBR) { + assert indexVersion.onOrAfter(Version.V_7_5_0); + ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length); + return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4); } } diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java index 9c54f267ca143..f286ab7328556 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java @@ -8,9 +8,11 @@ package org.elasticsearch.xpack.vectors.query; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.Version; +import org.elasticsearch.script.ScoreScript; import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder; -import java.util.Iterator; +import java.nio.ByteBuffer; import java.util.List; import java.util.Map; @@ -19,132 +21,166 @@ public class ScoreScriptUtils { //**************FUNCTIONS FOR DENSE VECTORS + // Functions are implemented as classes to accept a hidden parameter scoreScript that contains some index settings. 
+ // Also, constructors for some functions accept queryVector to calculate and cache queryVectorMagnitude only once + // per script execution for all documents. - /** - * Calculate l1 norm - Manhattan distance - * between a query's dense vector and documents' dense vectors - * - * @param queryVector the query vector parsed as {@code List} from json - * @param dvs VectorScriptDocValues representing encoded documents' vectors - */ - public static double l1norm(List queryVector, VectorScriptDocValues.DenseVectorScriptDocValues dvs){ - BytesRef value = dvs.getEncodedValue(); - float[] docVector = VectorEncoderDecoder.decodeDenseVector(value); - if (queryVector.size() != docVector.length) { - throw new IllegalArgumentException("Can't calculate l1norm! The number of dimensions of the query vector [" + - queryVector.size() + "] is different from the documents' vectors [" + docVector.length + "]."); + public static class DenseVectorFunction { + final ScoreScript scoreScript; + final float[] queryVector; + + public DenseVectorFunction(ScoreScript scoreScript, List queryVector) { + this(scoreScript, queryVector, false); + } + + /** + * Constructs a dense vector function. + * + * @param scoreScript The script in which this function was referenced. + * @param queryVector The query vector. + * @param normalizeQuery Whether the provided query should be normalized to unit length. 
+ */ + public DenseVectorFunction(ScoreScript scoreScript, + List queryVector, + boolean normalizeQuery) { + this.scoreScript = scoreScript; + + this.queryVector = new float[queryVector.size()]; + double queryMagnitude = 0.0; + for (int i = 0; i < queryVector.size(); i++) { + float value = queryVector.get(i).floatValue(); + this.queryVector[i] = value; + queryMagnitude += value * value; + } + queryMagnitude = Math.sqrt(queryMagnitude); + + if (normalizeQuery) { + for (int dim = 0; dim < this.queryVector.length; dim++) { + this.queryVector[dim] /= queryMagnitude; + } + } } - Iterator queryVectorIter = queryVector.iterator(); - double l1norm = 0; - for (int dim = 0; dim < docVector.length; dim++){ - l1norm += Math.abs(queryVectorIter.next().floatValue() - docVector[dim]); + + public void validateDocVector(BytesRef vector) { + if (vector == null) { + throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); + } + + int vectorLength = VectorEncoderDecoder.denseVectorLength(scoreScript._getIndexVersion(), vector); + if (queryVector.length != vectorLength) { + throw new IllegalArgumentException("The query vector has a different number of dimensions [" + + queryVector.length + "] than the document vectors [" + vectorLength + "]."); + } } - return l1norm; } - /** - * Calculate l2 norm - Euclidean distance - * between a query's dense vector and documents' dense vectors - * - * @param queryVector the query vector parsed as {@code List} from json - * @param dvs VectorScriptDocValues representing encoded documents' vectors - */ - public static double l2norm(List queryVector, VectorScriptDocValues.DenseVectorScriptDocValues dvs){ - BytesRef value = dvs.getEncodedValue(); - float[] docVector = VectorEncoderDecoder.decodeDenseVector(value); - if (queryVector.size() != docVector.length) { - throw new IllegalArgumentException("Can't calculate l2norm! 
The number of dimensions of the query vector [" + - queryVector.size() + "] is different from the documents' vectors [" + docVector.length + "]."); + // Calculate l1 norm (Manhattan distance) between a query's dense vector and documents' dense vectors + public static final class L1Norm extends DenseVectorFunction { + + public L1Norm(ScoreScript scoreScript, List queryVector) { + super(scoreScript, queryVector); } - Iterator queryVectorIter = queryVector.iterator(); - double l2norm = 0; - for (int dim = 0; dim < docVector.length; dim++){ - double diff = queryVectorIter.next().floatValue() - docVector[dim]; - l2norm += diff * diff; + + public double l1norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { + BytesRef vector = dvs.getEncodedValue(); + validateDocVector(vector); + ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); + double l1norm = 0; + + for (float queryValue : queryVector) { + l1norm += Math.abs(queryValue - byteBuffer.getFloat()); + } + return l1norm; } - return Math.sqrt(l2norm); } + // Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors + public static final class L2Norm extends DenseVectorFunction { - /** - * Calculate a dot product between a query's dense vector and documents' dense vectors - * - * @param queryVector the query vector parsed as {@code List} from json - * @param dvs VectorScriptDocValues representing encoded documents' vectors - */ - public static double dotProduct(List queryVector, VectorScriptDocValues.DenseVectorScriptDocValues dvs){ - BytesRef value = dvs.getEncodedValue(); - float[] docVector = VectorEncoderDecoder.decodeDenseVector(value); - if (queryVector.size() != docVector.length) { - throw new IllegalArgumentException("Can't calculate dotProduct! 
The number of dimensions of the query vector [" + - queryVector.size() + "] is different from the documents' vectors [" + docVector.length + "]."); + public L2Norm(ScoreScript scoreScript, List queryVector) { + super(scoreScript, queryVector); + } + + public double l2norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { + BytesRef vector = dvs.getEncodedValue(); + validateDocVector(vector); + ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); + + double l2norm = 0; + for (float queryValue : queryVector) { + double diff = queryValue - byteBuffer.getFloat(); + l2norm += diff * diff; + } + return Math.sqrt(l2norm); } - return intDotProduct(queryVector, docVector); } - /** - * Calculate cosine similarity between a query's dense vector and documents' dense vectors - * - * CosineSimilarity is implemented as a class to use - * painless script caching to calculate queryVectorMagnitude - * only once per script execution for all documents. - * A user will call `cosineSimilarity(params.queryVector, doc['my_vector'])` - */ - public static final class CosineSimilarity { - final double queryVectorMagnitude; - final List queryVector; + // Calculate a dot product between a query's dense vector and documents' dense vectors + public static final class DotProduct extends DenseVectorFunction { - // calculate queryVectorMagnitude once per query execution - public CosineSimilarity(List queryVector) { - this.queryVector = queryVector; + public DotProduct(ScoreScript scoreScript, List queryVector) { + super(scoreScript, queryVector); + } + + public double dotProduct(VectorScriptDocValues.DenseVectorScriptDocValues dvs){ + BytesRef vector = dvs.getEncodedValue(); + validateDocVector(vector); + ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); double dotProduct = 0; - for (Number value : queryVector) { - float floatValue = value.floatValue(); - dotProduct += floatValue * floatValue; + for (float queryValue : 
queryVector) { + dotProduct += queryValue * byteBuffer.getFloat(); } - this.queryVectorMagnitude = Math.sqrt(dotProduct); + return dotProduct; } + } - public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { - BytesRef value = dvs.getEncodedValue(); - float[] docVector = VectorEncoderDecoder.decodeDenseVector(value); - if (queryVector.size() != docVector.length) { - throw new IllegalArgumentException("Can't calculate cosineSimilarity! The number of dimensions of the query vector [" + - queryVector.size() + "] is different from the documents' vectors [" + docVector.length + "]."); - } - - // calculate docVector magnitude - double dotProduct = 0f; - for (int dim = 0; dim < docVector.length; dim++) { - dotProduct += (double) docVector[dim] * docVector[dim]; - } - final double docVectorMagnitude = Math.sqrt(dotProduct); + // Calculate cosine similarity between a query's dense vector and documents' dense vectors + public static final class CosineSimilarity extends DenseVectorFunction { - double docQueryDotProduct = intDotProduct(queryVector, docVector); - return docQueryDotProduct / (docVectorMagnitude * queryVectorMagnitude); + public CosineSimilarity(ScoreScript scoreScript, List queryVector) { + super(scoreScript, queryVector, true); } - } - private static double intDotProduct(List v1, float[] v2){ - double v1v2DotProduct = 0; - Iterator v1Iter = v1.iterator(); - for (int dim = 0; dim < v2.length; dim++) { - v1v2DotProduct += v1Iter.next().floatValue() * v2[dim]; + public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { + BytesRef vector = dvs.getEncodedValue(); + validateDocVector(vector); + + ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); + + double dotProduct = 0.0; + double vectorMagnitude = 0.0f; + if (scoreScript._getIndexVersion().onOrAfter(Version.V_7_5_0)) { + for (float queryValue : queryVector) { + dotProduct += queryValue * byteBuffer.getFloat(); + } + 
vectorMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(scoreScript._getIndexVersion(), vector); + } else { + for (float queryValue : queryVector) { + float docValue = byteBuffer.getFloat(); + dotProduct += queryValue * docValue; + vectorMagnitude += docValue * docValue; + } + vectorMagnitude = (float) Math.sqrt(vectorMagnitude); + } + return dotProduct / vectorMagnitude; } - return v1v2DotProduct; } - //**************FUNCTIONS FOR SPARSE VECTORS + // Functions are implemented as classes to accept a hidden parameter scoreScript that contains some index settings. + // Also, constructors for some functions accept queryVector to calculate and cache queryVectorMagnitude only once + // per script execution for all documents. - public static class VectorSparseFunctions { + public static class SparseVectorFunction { + final ScoreScript scoreScript; final float[] queryValues; final int[] queryDims; // prepare queryVector once per script execution // queryVector represents a map of dimensions to values - public VectorSparseFunctions(Map queryVector) { + public SparseVectorFunction(ScoreScript scoreScript, Map queryVector) { + this.scoreScript = scoreScript; //break vector into two arrays dims and values int n = queryVector.size(); queryValues = new float[n]; @@ -162,26 +198,26 @@ public VectorSparseFunctions(Map queryVector) { // Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions sortSparseDimsFloatValues(queryDims, queryValues, n); } + + public void validateDocVector(BytesRef vector) { + if (vector == null) { + throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); + } + } } - /** - * Calculate l1 norm - Manhattan distance - * between a query's sparse vector and documents' sparse vectors - * - * L1NormSparse is implemented as a class to use - * painless script caching to prepare queryVector - * only once per script execution for all documents. 
- * A user will call `l1normSparse(params.queryVector, doc['my_vector'])` - */ - public static final class L1NormSparse extends VectorSparseFunctions { - public L1NormSparse(Map queryVector) { - super(queryVector); + // Calculate l1 norm (Manhattan distance) between a query's sparse vector and documents' sparse vectors + public static final class L1NormSparse extends SparseVectorFunction { + public L1NormSparse(ScoreScript scoreScript,Map queryVector) { + super(scoreScript, queryVector); } public double l1normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { - BytesRef value = dvs.getEncodedValue(); - int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(value); - float[] docValues = VectorEncoderDecoder.decodeSparseVector(value); + BytesRef vector = dvs.getEncodedValue(); + validateDocVector(vector); + + int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); + float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); int queryIndex = 0; int docIndex = 0; double l1norm = 0; @@ -210,24 +246,18 @@ public double l1normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs } } - /** - * Calculate l2 norm - Euclidean distance - * between a query's sparse vector and documents' sparse vectors - * - * L2NormSparse is implemented as a class to use - * painless script caching to prepare queryVector - * only once per script execution for all documents. 
- * A user will call `l2normSparse(params.queryVector, doc['my_vector'])` - */ - public static final class L2NormSparse extends VectorSparseFunctions { - public L2NormSparse(Map queryVector) { - super(queryVector); + // Calculate l2 norm (Euclidean distance) between a query's sparse vector and documents' sparse vectors + public static final class L2NormSparse extends SparseVectorFunction { + public L2NormSparse(ScoreScript scoreScript, Map queryVector) { + super(scoreScript, queryVector); } public double l2normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { - BytesRef value = dvs.getEncodedValue(); - int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(value); - float[] docValues = VectorEncoderDecoder.decodeSparseVector(value); + BytesRef vector = dvs.getEncodedValue(); + validateDocVector(vector); + + int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); + float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); int queryIndex = 0; int docIndex = 0; double l2norm = 0; @@ -259,40 +289,28 @@ public double l2normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs } } - /** - * Calculate a dot product between a query's sparse vector and documents' sparse vectors - * - * DotProductSparse is implemented as a class to use - * painless script caching to prepare queryVector - * only once per script execution for all documents. 
- * A user will call `dotProductSparse(params.queryVector, doc['my_vector'])` - */ - public static final class DotProductSparse extends VectorSparseFunctions { - public DotProductSparse(Map queryVector) { - super(queryVector); + // Calculate a dot product between a query's sparse vector and documents' sparse vectors + public static final class DotProductSparse extends SparseVectorFunction { + public DotProductSparse(ScoreScript scoreScript, Map queryVector) { + super(scoreScript, queryVector); } public double dotProductSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { - BytesRef value = dvs.getEncodedValue(); - int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(value); - float[] docValues = VectorEncoderDecoder.decodeSparseVector(value); + BytesRef vector = dvs.getEncodedValue(); + validateDocVector(vector); + + int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); + float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); return intDotProductSparse(queryValues, queryDims, docValues, docDims); } } - /** - * Calculate cosine similarity between a query's sparse vector and documents' sparse vectors - * - * CosineSimilaritySparse is implemented as a class to use - * painless script caching to prepare queryVector and calculate queryVectorMagnitude - * only once per script execution for all documents. 
- * A user will call `cosineSimilaritySparse(params.queryVector, doc['my_vector'])` - */ - public static final class CosineSimilaritySparse extends VectorSparseFunctions { + // Calculate cosine similarity between a query's sparse vector and documents' sparse vectors + public static final class CosineSimilaritySparse extends SparseVectorFunction { final double queryVectorMagnitude; - public CosineSimilaritySparse(Map queryVector) { - super(queryVector); + public CosineSimilaritySparse(ScoreScript scoreScript, Map queryVector) { + super(scoreScript, queryVector); double dotProduct = 0; for (int i = 0; i< queryDims.length; i++) { dotProduct += queryValues[i] * queryValues[i]; @@ -301,18 +319,23 @@ public CosineSimilaritySparse(Map queryVector) { } public double cosineSimilaritySparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { - BytesRef value = dvs.getEncodedValue(); - int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(value); - float[] docValues = VectorEncoderDecoder.decodeSparseVector(value); + BytesRef vector = dvs.getEncodedValue(); + validateDocVector(vector); - // calculate docVector magnitude - double dotProduct = 0; - for (float docValue : docValues) { - dotProduct += (double) docValue * docValue; - } - final double docVectorMagnitude = Math.sqrt(dotProduct); + int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); + float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); double docQueryDotProduct = intDotProductSparse(queryValues, queryDims, docValues, docDims); + double docVectorMagnitude = 0.0f; + if (scoreScript._getIndexVersion().onOrAfter(Version.V_7_5_0)) { + docVectorMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(scoreScript._getIndexVersion(), vector); + } else { + for (float docValue : docValues) { + docVectorMagnitude += docValue * docValue; + } + docVectorMagnitude = (float) Math.sqrt(docVectorMagnitude); + } + return 
docQueryDotProduct / (docVectorMagnitude * queryVectorMagnitude); } } diff --git a/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt b/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt index 252d4356f9ca1..42d6e6d0b0f7a 100644 --- a/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt +++ b/x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt @@ -9,14 +9,16 @@ class org.elasticsearch.xpack.vectors.query.VectorScriptDocValues$DenseVectorScr } class org.elasticsearch.xpack.vectors.query.VectorScriptDocValues$SparseVectorScriptDocValues { } +class org.elasticsearch.script.ScoreScript @no_import { +} static_import { - double l1norm(List, VectorScriptDocValues.DenseVectorScriptDocValues) from_class org.elasticsearch.xpack.vectors.query.ScoreScriptUtils - double l2norm(List, VectorScriptDocValues.DenseVectorScriptDocValues) from_class org.elasticsearch.xpack.vectors.query.ScoreScriptUtils - double cosineSimilarity(List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilarity - double dotProduct(List, VectorScriptDocValues.DenseVectorScriptDocValues) from_class org.elasticsearch.xpack.vectors.query.ScoreScriptUtils - double l1normSparse(Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1NormSparse - double l2normSparse(Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2NormSparse - double dotProductSparse(Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProductSparse - double cosineSimilaritySparse(Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to 
org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilaritySparse + double l1norm(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm + double l2norm(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm + double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilarity + double dotProduct(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProduct + double l1normSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1NormSparse + double l2normSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2NormSparse + double dotProductSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProductSparse + double cosineSimilaritySparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilaritySparse } \ No newline at end of file diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java index d1b37c73a246e..52ef487935b68 100644 --- 
a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java @@ -10,9 +10,12 @@ import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.index.IndexableField; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.Version; +import org.elasticsearch.cluster.metadata.IndexMetaData; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.compress.CompressedXContent; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.IndexService; @@ -27,6 +30,7 @@ import org.elasticsearch.xpack.vectors.Vectors; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Collection; import static org.hamcrest.Matchers.containsString; @@ -38,6 +42,12 @@ protected Collection> getPlugins() { return pluginList(Vectors.class, XPackPlugin.class); } + // this allows to set indexVersion as it is a private setting + @Override + protected boolean forbidPrivateIndexSettings() { + return false; + } + public void testMappingExceedDimsLimit() throws IOException { IndexService indexService = createIndex("test-index"); DocumentMapperParser parser = indexService.mapperService().documentMapperParser(); @@ -55,6 +65,7 @@ public void testMappingExceedDimsLimit() throws IOException { } public void testDefaults() throws Exception { + Version indexVersion = Version.CURRENT; IndexService indexService = createIndex("test-index"); DocumentMapperParser parser = indexService.mapperService().documentMapperParser(); String mapping = Strings.toString(XContentFactory.jsonBuilder() @@ -69,6 +80,11 @@ public void testDefaults() throws Exception { DocumentMapper mapper = parser.parse("_doc", new 
CompressedXContent(mapping)); float[] validVector = {-12.1f, 100.7f, -4}; + double dotProduct = 0.0f; + for (float value: validVector) { + dotProduct += value * value; + } + float expectedMagnitude = (float) Math.sqrt(dotProduct); ParsedDocument doc1 = mapper.parse(new SourceToParse("test-index", "_doc", "1", BytesReference .bytes(XContentFactory.jsonBuilder() .startObject() @@ -80,7 +96,9 @@ public void testDefaults() throws Exception { assertThat(fields[0], instanceOf(BinaryDocValuesField.class)); // assert that after decoding the indexed value is equal to expected BytesRef vectorBR = fields[0].binaryValue(); - float[] decodedValues = VectorEncoderDecoder.decodeDenseVector(vectorBR); + float[] decodedValues = decodeDenseVector(indexVersion, vectorBR); + float decodedMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(indexVersion, vectorBR); + assertEquals(expectedMagnitude, decodedMagnitude, 0.001f); assertArrayEquals( "Decoded dense vector values is not equal to the indexed one.", validVector, @@ -89,6 +107,53 @@ public void testDefaults() throws Exception { ); } + public void testAddDocumentsToIndexBefore_V_7_5_0() throws Exception { + Version indexVersion = Version.V_7_4_0; + IndexService indexService = createIndex("test-index7_4", + Settings.builder().put(IndexMetaData.SETTING_INDEX_VERSION_CREATED.getKey(), indexVersion).build()); + DocumentMapperParser parser = indexService.mapperService().documentMapperParser(); + String mapping = Strings.toString(XContentFactory.jsonBuilder() + .startObject() + .startObject("_doc") + .startObject("properties") + .startObject("my-dense-vector").field("type", "dense_vector").field("dims", 3) + .endObject() + .endObject() + .endObject() + .endObject()); + DocumentMapper mapper = parser.parse("_doc", new CompressedXContent(mapping)); + float[] validVector = {-12.1f, 100.7f, -4}; + ParsedDocument doc1 = mapper.parse(new SourceToParse("test-index7_4", "_doc", "1", BytesReference + .bytes(XContentFactory.jsonBuilder() + 
.startObject() + .startArray("my-dense-vector").value(validVector[0]).value(validVector[1]).value(validVector[2]).endArray() + .endObject()), + XContentType.JSON)); + IndexableField[] fields = doc1.rootDoc().getFields("my-dense-vector"); + assertEquals(1, fields.length); + assertThat(fields[0], instanceOf(BinaryDocValuesField.class)); + // assert that after decoding the indexed value is equal to expected + BytesRef vectorBR = fields[0].binaryValue(); + float[] decodedValues = decodeDenseVector(indexVersion, vectorBR); + assertArrayEquals( + "Decoded dense vector values is not equal to the indexed one.", + validVector, + decodedValues, + 0.001f + ); + } + + private static float[] decodeDenseVector(Version indexVersion, BytesRef encodedVector) { + int dimCount = VectorEncoderDecoder.denseVectorLength(indexVersion, encodedVector); + float[] vector = new float[dimCount]; + + ByteBuffer byteBuffer = ByteBuffer.wrap(encodedVector.bytes, encodedVector.offset, encodedVector.length); + for (int dim = 0; dim < dimCount; dim++) { + vector[dim] = byteBuffer.getFloat(); + } + return vector; + } + public void testDocumentsWithIncorrectDims() throws Exception { IndexService indexService = createIndex("test-index"); int dims = 3; diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapperTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapperTests.java index e1e110a750b93..915908ade4282 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapperTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapperTests.java @@ -10,9 +10,12 @@ import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.index.IndexableField; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.Version; +import org.elasticsearch.cluster.metadata.IndexMetaData; 
import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.compress.CompressedXContent; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.IndexService; @@ -43,7 +46,7 @@ public class SparseVectorFieldMapperTests extends ESSingleNodeTestCase { @Before public void setUpMapper() throws Exception { - IndexService indexService = createIndex("test-index"); + IndexService indexService = createIndex("test-index"); DocumentMapperParser parser = indexService.mapperService().documentMapperParser(); String mapping = Strings.toString(XContentFactory.jsonBuilder() .startObject() @@ -62,7 +65,14 @@ protected Collection> getPlugins() { return pluginList(Vectors.class, XPackPlugin.class); } + // this allows to set indexVersion as it is a private setting + @Override + protected boolean forbidPrivateIndexSettings() { + return false; + } + public void testDefaults() throws Exception { + Version indexVersion = Version.CURRENT; int[] indexedDims = {65535, 50, 2}; float[] indexedValues = {0.5f, 1800f, -34567.11f}; ParsedDocument doc1 = mapper.parse(new SourceToParse("test-index", "_doc", "1", BytesReference @@ -79,19 +89,79 @@ public void testDefaults() throws Exception { assertEquals(1, fields.length); assertThat(fields[0], Matchers.instanceOf(BinaryDocValuesField.class)); + // assert that after decoding the indexed values are equal to expected + int[] expectedDims = {2, 50, 65535}; //the same as indexed but sorted + float[] expectedValues = {-34567.11f, 1800f, 0.5f}; //the same as indexed but sorted by their dimensions + double dotProduct = 0.0f; + for (float value: expectedValues) { + dotProduct += value * value; + } + float expectedMagnitude = (float) Math.sqrt(dotProduct); + + // assert that after decoded magnitude, dims and values are equal to expected + BytesRef vectorBR 
= fields[0].binaryValue(); + int[] decodedDims = VectorEncoderDecoder.decodeSparseVectorDims(indexVersion, vectorBR); + assertArrayEquals( + "Decoded sparse vector dimensions are not equal to the indexed ones.", + expectedDims, + decodedDims + ); + float[] decodedValues = VectorEncoderDecoder.decodeSparseVector(indexVersion, vectorBR); + assertArrayEquals( + "Decoded sparse vector values are not equal to the indexed ones.", + expectedValues, + decodedValues, + 0.001f + ); + float decodedMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(indexVersion, vectorBR); + assertEquals(expectedMagnitude, decodedMagnitude, 0.001f); + } + + public void testAddDocumentsToIndexBefore_V_7_5_0() throws Exception { + Version indexVersion = Version.V_7_4_0; + IndexService indexService = createIndex("test-index7_4", + Settings.builder().put(IndexMetaData.SETTING_INDEX_VERSION_CREATED.getKey(), indexVersion).build()); + DocumentMapperParser parser = indexService.mapperService().documentMapperParser(); + String mapping = Strings.toString(XContentFactory.jsonBuilder() + .startObject() + .startObject("_doc") + .startObject("properties") + .startObject("my-sparse-vector").field("type", "sparse_vector") + .endObject() + .endObject() + .endObject() + .endObject()); + mapper = parser.parse("_doc", new CompressedXContent(mapping)); + + int[] indexedDims = {65535, 50, 2}; + float[] indexedValues = {0.5f, 1800f, -34567.11f}; + ParsedDocument doc1 = mapper.parse(new SourceToParse("test-index7_4", "_doc", "1", BytesReference + .bytes(XContentFactory.jsonBuilder() + .startObject() + .startObject("my-sparse-vector") + .field(Integer.toString(indexedDims[0]), indexedValues[0]) + .field(Integer.toString(indexedDims[1]), indexedValues[1]) + .field(Integer.toString(indexedDims[2]), indexedValues[2]) + .endObject() + .endObject()), + XContentType.JSON)); + IndexableField[] fields = doc1.rootDoc().getFields("my-sparse-vector"); + assertEquals(1, fields.length); + assertThat(fields[0], 
Matchers.instanceOf(BinaryDocValuesField.class)); + // assert that after decoding the indexed values are equal to expected int[] expectedDims = {2, 50, 65535}; //the same as indexed but sorted float[] expectedValues = {-34567.11f, 1800f, 0.5f}; //the same as indexed but sorted by their dimensions - // assert that after decoding the indexed dims and values are equal to expected - BytesRef vectorBR = ((BinaryDocValuesField) fields[0]).binaryValue(); - int[] decodedDims = VectorEncoderDecoder.decodeSparseVectorDims(vectorBR); + // assert that after decoded magnitude, dims and values are equal to expected + BytesRef vectorBR = fields[0].binaryValue(); + int[] decodedDims = VectorEncoderDecoder.decodeSparseVectorDims(indexVersion, vectorBR); assertArrayEquals( "Decoded sparse vector dimensions are not equal to the indexed ones.", expectedDims, decodedDims ); - float[] decodedValues = VectorEncoderDecoder.decodeSparseVector(vectorBR); + float[] decodedValues = VectorEncoderDecoder.decodeSparseVector(indexVersion, vectorBR); assertArrayEquals( "Decoded sparse vector values are not equal to the indexed ones.", expectedValues, @@ -185,4 +255,5 @@ public void testDimensionLimit() throws IOException { new SourceToParse("test-index", "_doc", "1", invalidDoc, XContentType.JSON))); assertThat(e.getDetailedMessage(), containsString("has exceeded the maximum allowed number of dimensions")); } + } diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java index 939d999b0d9aa..c81bdfe147ebd 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java @@ -7,33 +7,59 @@ package org.elasticsearch.xpack.vectors.mapper; import 
org.apache.lucene.util.BytesRef; +import org.elasticsearch.Version; import org.elasticsearch.test.ESTestCase; +import java.nio.ByteBuffer; import java.util.HashSet; import java.util.Set; import java.util.Arrays; public class VectorEncoderDecoderTests extends ESTestCase { - public void testDenseVectorEncodingDecoding() { - int dimCount = randomIntBetween(0, DenseVectorFieldMapper.MAX_DIMS_COUNT); + public void testSparseVectorEncodingDecoding() { + Version indexVersion = Version.CURRENT; + int dimCount = randomIntBetween(0, 100); float[] expectedValues = new float[dimCount]; + int[] expectedDims = randomUniqueDims(dimCount); + double dotProduct = 0.0f; for (int i = 0; i < dimCount; i++) { expectedValues[i] = randomFloat(); + dotProduct += expectedValues[i] * expectedValues[i]; } + float expectedMagnitude = (float) Math.sqrt(dotProduct); + + // test that sorting in the encoding works as expected + int[] sortedDims = Arrays.copyOf(expectedDims, dimCount); + Arrays.sort(sortedDims); + VectorEncoderDecoder.sortSparseDimsValues(expectedDims, expectedValues, dimCount); + assertArrayEquals( + "Sparse vector dims are not properly sorted!", + sortedDims, + expectedDims + ); // test that values that went through encoding and decoding are equal to their original - BytesRef encodedDenseVector = mockEncodeDenseVector(expectedValues); - float[] decodedValues = VectorEncoderDecoder.decodeDenseVector(encodedDenseVector); + BytesRef encodedSparseVector = VectorEncoderDecoder.encodeSparseVector(indexVersion, expectedDims, expectedValues, dimCount); + int[] decodedDims = VectorEncoderDecoder.decodeSparseVectorDims(indexVersion, encodedSparseVector); + float[] decodedValues = VectorEncoderDecoder.decodeSparseVector(indexVersion, encodedSparseVector); + float decodedMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(indexVersion, encodedSparseVector); + assertEquals(expectedMagnitude, decodedMagnitude, 0.0f); + assertArrayEquals( + "Decoded sparse vector dims are not equal to their 
original!", + expectedDims, + decodedDims + ); assertArrayEquals( - "Decoded dense vector values are not equal to their original.", + "Decoded sparse vector values are not equal to their original.", expectedValues, decodedValues, 0.001f ); } - public void testSparseVectorEncodingDecoding() { + public void testSparseVectorEncodingDecodingBefore_V_7_5_0() { + Version indexVersion = Version.V_7_4_0; int dimCount = randomIntBetween(0, 100); float[] expectedValues = new float[dimCount]; int[] expectedDims = randomUniqueDims(dimCount); @@ -52,9 +78,9 @@ public void testSparseVectorEncodingDecoding() { ); // test that values that went through encoding and decoding are equal to their original - BytesRef encodedSparseVector = VectorEncoderDecoder.encodeSparseVector(expectedDims, expectedValues, dimCount); - int[] decodedDims = VectorEncoderDecoder.decodeSparseVectorDims(encodedSparseVector); - float[] decodedValues = VectorEncoderDecoder.decodeSparseVector(encodedSparseVector); + BytesRef encodedSparseVector = VectorEncoderDecoder.encodeSparseVector(indexVersion, expectedDims, expectedValues, dimCount); + int[] decodedDims = VectorEncoderDecoder.decodeSparseVectorDims(indexVersion, encodedSparseVector); + float[] decodedValues = VectorEncoderDecoder.decodeSparseVector(indexVersion, encodedSparseVector); assertArrayEquals( "Decoded sparse vector dims are not equal to their original!", expectedDims, @@ -69,23 +95,28 @@ public void testSparseVectorEncodingDecoding() { } // imitates the code in DenseVectorFieldMapper::parse - public static BytesRef mockEncodeDenseVector(float[] values) { - final short INT_BYTES = VectorEncoderDecoder.INT_BYTES; - byte[] buf = new byte[INT_BYTES * values.length]; - int offset = 0; - int intValue; - for (float value: values) { - intValue = Float.floatToIntBits(value); - buf[offset++] = (byte) (intValue >> 24); - buf[offset++] = (byte) (intValue >> 16); - buf[offset++] = (byte) (intValue >> 8); - buf[offset++] = (byte) intValue; + public static 
BytesRef mockEncodeDenseVector(float[] values, Version indexVersion) { + byte[] bytes = indexVersion.onOrAfter(Version.V_7_5_0) + ? new byte[VectorEncoderDecoder.INT_BYTES * values.length + VectorEncoderDecoder.INT_BYTES] + : new byte[VectorEncoderDecoder.INT_BYTES * values.length]; + double dotProduct = 0f; + + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + for (float value : values) { + byteBuffer.putFloat(value); + dotProduct += value * value; + } + + if (indexVersion.onOrAfter(Version.V_7_5_0)) { + // encode vector magnitude at the end + float vectorMagnitude = (float) Math.sqrt(dotProduct); + byteBuffer.putFloat(vectorMagnitude); } - return new BytesRef(buf, 0, offset); + return new BytesRef(bytes); } // generate unique random dims - private int[] randomUniqueDims(int dimCount) { + private static int[] randomUniqueDims(int dimCount) { int[] values = new int[dimCount]; Set usedValues = new HashSet<>(); int value; diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java index 87f8f83c06bd7..343db845e6e3b 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java @@ -7,12 +7,17 @@ package org.elasticsearch.xpack.vectors.query; import org.apache.lucene.util.BytesRef; -import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder; +import org.elasticsearch.Version; +import org.elasticsearch.script.ScoreScript; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder; import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.CosineSimilarity; -import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.DotProductSparse; import 
org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.CosineSimilaritySparse; +import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.DotProduct; +import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.DotProductSparse; +import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L1Norm; import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L1NormSparse; +import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2Norm; import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2NormSparse; import java.util.Arrays; @@ -21,65 +26,85 @@ import java.util.Map; import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoderTests.mockEncodeDenseVector; -import static org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.dotProduct; -import static org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.l1norm; -import static org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.l2norm; - import static org.hamcrest.Matchers.containsString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; - public class ScoreScriptUtilsTests extends ESTestCase { + public void testDenseVectorFunctions() { + testDenseVectorFunctions(Version.V_7_4_0); + testDenseVectorFunctions(Version.CURRENT); + } + + private void testDenseVectorFunctions(Version indexVersion) { float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; - BytesRef encodedDocVector = mockEncodeDenseVector(docVector); + BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion); VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + + ScoreScript scoreScript = mock(ScoreScript.class); + when(scoreScript._getIndexVersion()).thenReturn(indexVersion); + List queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f); // test dotProduct - double result = dotProduct(queryVector, dvs); + DotProduct 
dotProduct = new DotProduct(scoreScript, queryVector); + double result = dotProduct.dotProduct(dvs); assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001); // test cosineSimilarity - CosineSimilarity cosineSimilarity = new CosineSimilarity(queryVector); + CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector); double result2 = cosineSimilarity.cosineSimilarity(dvs); assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result2, 0.001); // test l1Norm - double result3 = l1norm(queryVector, dvs); + L1Norm l1norm = new L1Norm(scoreScript, queryVector); + double result3 = l1norm.l1norm(dvs); assertEquals("l1norm result is not equal to the expected value!", 485.184, result3, 0.001); // test l2norm - double result4 = l2norm(queryVector, dvs); + L2Norm l2norm = new L2Norm(scoreScript, queryVector); + double result4 = l2norm.l2norm(dvs); assertEquals("l2norm result is not equal to the expected value!", 301.361, result4, 0.001); // test dotProduct fails when queryVector has wrong number of dims List invalidQueryVector = Arrays.asList(0.5, 111.3); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> dotProduct(invalidQueryVector, dvs)); - assertThat(e.getMessage(), containsString("dimensions of the query vector [2] is different from the documents' vectors [5]")); + DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> dotProduct2.dotProduct(dvs)); + assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); // test cosineSimilarity fails when queryVector has wrong number of dims - CosineSimilarity cosineSimilarity2 = new CosineSimilarity(invalidQueryVector); + CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector); e = 
expectThrows(IllegalArgumentException.class, () -> cosineSimilarity2.cosineSimilarity(dvs)); - assertThat(e.getMessage(), containsString("dimensions of the query vector [2] is different from the documents' vectors [5]")); + assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); // test l1norm fails when queryVector has wrong number of dims - e = expectThrows(IllegalArgumentException.class, () -> l1norm(invalidQueryVector, dvs)); - assertThat(e.getMessage(), containsString("dimensions of the query vector [2] is different from the documents' vectors [5]")); + L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector); + e = expectThrows(IllegalArgumentException.class, () -> l1norm2.l1norm(dvs)); + assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); // test l2norm fails when queryVector has wrong number of dims - e = expectThrows(IllegalArgumentException.class, () -> l2norm(invalidQueryVector, dvs)); - assertThat(e.getMessage(), containsString("dimensions of the query vector [2] is different from the documents' vectors [5]")); + L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector); + e = expectThrows(IllegalArgumentException.class, () -> l2norm2.l2norm(dvs)); + assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); } public void testSparseVectorFunctions() { + testSparseVectorFunctions(Version.V_7_4_0); + testSparseVectorFunctions(Version.CURRENT); + } + + private void testSparseVectorFunctions(Version indexVersion) { int[] docVectorDims = {2, 10, 50, 113, 4545}; float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; - BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(docVectorDims, docVectorValues, docVectorDims.length); + BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector( + 
indexVersion, docVectorDims, docVectorValues, docVectorDims.length); VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + ScoreScript scoreScript = mock(ScoreScript.class); + when(scoreScript._getIndexVersion()).thenReturn(indexVersion); + Map queryVector = new HashMap() {{ put("2", 0.5); put("10", 111.3); @@ -89,22 +114,22 @@ public void testSparseVectorFunctions() { }}; // test dotProduct - DotProductSparse docProductSparse = new DotProductSparse(queryVector); + DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector); double result = docProductSparse.dotProductSparse(dvs); assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); // test cosineSimilarity - CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector); + CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector); double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result2, 0.001); // test l1norm - L1NormSparse l1Norm = new L1NormSparse(queryVector); + L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector); double result3 = l1Norm.l1normSparse(dvs); assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result3, 0.001); // test l2norm - L2NormSparse l2Norm = new L2NormSparse(queryVector); + L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector); double result4 = l2Norm.l2normSparse(dvs); assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result4, 0.001); } @@ -113,9 +138,12 @@ public void testSparseVectorMissingDimensions1() { // Document vector's biggest dimension > query vector's biggest dimension int[] docVectorDims = {2, 10, 50, 113, 4545, 4546}; float[] 
docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f}; - BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(docVectorDims, docVectorValues, docVectorDims.length); + BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector( + Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length); VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + ScoreScript scoreScript = mock(ScoreScript.class); + when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT); Map queryVector = new HashMap() {{ put("2", 0.5); put("10", 111.3); @@ -126,22 +154,22 @@ public void testSparseVectorMissingDimensions1() { }}; // test dotProduct - DotProductSparse docProductSparse = new DotProductSparse(queryVector); + DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector); double result = docProductSparse.dotProductSparse(dvs); assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); // test cosineSimilarity - CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector); + CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector); double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001); // test l1norm - L1NormSparse l1Norm = new L1NormSparse(queryVector); + L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector); double result3 = l1Norm.l1normSparse(dvs); assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001); // test l2norm - L2NormSparse l2Norm = new L2NormSparse(queryVector); + L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector); double result4 = l2Norm.l2normSparse(dvs); 
assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001); } @@ -150,9 +178,12 @@ public void testSparseVectorMissingDimensions2() { // Document vector's biggest dimension < query vector's biggest dimension int[] docVectorDims = {2, 10, 50, 113, 4545, 4546}; float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f}; - BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(docVectorDims, docVectorValues, docVectorDims.length); + BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector( + Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length); VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); + ScoreScript scoreScript = mock(ScoreScript.class); + when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT); Map queryVector = new HashMap() {{ put("2", 0.5); put("10", 111.3); @@ -163,22 +194,22 @@ public void testSparseVectorMissingDimensions2() { }}; // test dotProduct - DotProductSparse docProductSparse = new DotProductSparse(queryVector); + DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector); double result = docProductSparse.dotProductSparse(dvs); assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); // test cosineSimilarity - CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector); + CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector); double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001); // test l1norm - L1NormSparse l1Norm = new L1NormSparse(queryVector); + L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector); double result3 = 
l1Norm.l1normSparse(dvs); assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001); // test l2norm - L2NormSparse l2Norm = new L2NormSparse(queryVector); + L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector); double result4 = l2Norm.l2normSparse(dvs); assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001); }