diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java index 686823c248bb9..e7ca4f1f297a5 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java @@ -162,47 +162,21 @@ public void swap(int i, int j) { }.sort(0, n); } - /** - * Decodes a BytesRef into an array of floats - * @param indexVersion - index Version - * @param vectorBR - dense vector encoded in BytesRef - */ - public static float[] decodeDenseVector(Version indexVersion, BytesRef vectorBR) { - if (vectorBR == null) { - throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); - } - - int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / INT_BYTES : vectorBR.length/ INT_BYTES; - float[] vector = new float[dimCount]; - - ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length); - for (int dim = 0; dim < dimCount; dim++) { - vector[dim] = byteBuffer.getFloat(); - } - return vector; + public static int denseVectorLength(Version indexVersion, BytesRef vectorBR) { + return indexVersion.onOrAfter(Version.V_7_4_0) + ? (vectorBR.length - INT_BYTES) / INT_BYTES + : vectorBR.length / INT_BYTES; } /** - * Calculates vector magnitude either by - * decoding last 4 bytes of BytesRef into a vector magnitude or calculating it - * @param indexVersion - index Version - * @param vectorBR - vector encoded in BytesRef - * @param vector - float vector + * Decodes the last 4 bytes of the encoded vector, which contains the vector magnitude. + * NOTE: this function can only be called on vectors from an index version greater than or + * equal to 7.4.0, since vectors created prior to that do not store the magnitude. */ - public static float getVectorMagnitude(Version indexVersion, BytesRef vectorBR, float[] vector) { - if (vectorBR == null) { - throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); - } + public static float decodeVectorMagnitude(Version indexVersion, BytesRef vectorBR) { + assert indexVersion.onOrAfter(Version.V_7_4_0); + ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length); + return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4); - if (indexVersion.onOrAfter(Version.V_7_4_0)) { // decode vector magnitude - ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length); - return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4); - } else { // calculate vector magnitude - double dotProduct = 0f; - for (int dim = 0; dim < vector.length; dim++) { - dotProduct += (double) vector[dim] * vector[dim]; - } - return (float) Math.sqrt(dotProduct); - } } } diff --git a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java index fa4b3a15916c6..b97c5d5b35084 100644 --- a/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java +++ b/x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java @@ -8,9 +8,11 @@ package org.elasticsearch.xpack.vectors.query; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.Version; import org.elasticsearch.script.ScoreScript; import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder; +import java.nio.ByteBuffer; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -32,16 +34,21 @@ public L1Norm(ScoreScript scoreScript) { } public double l1norm(List queryVector, VectorScriptDocValues.DenseVectorScriptDocValues dvs){ - BytesRef value = dvs.getEncodedValue(); - float[] docVector = VectorEncoderDecoder.decodeDenseVector(scoreScript._getIndexVersion(), value); - if (queryVector.size() != docVector.length) { + BytesRef vector = dvs.getEncodedValue(); + int vectorLength = VectorEncoderDecoder.denseVectorLength(scoreScript._getIndexVersion(), vector); + if (queryVector.size() != vectorLength) { throw new IllegalArgumentException("Can't calculate l1norm! The number of dimensions of the query vector [" + - queryVector.size() + "] is different from the documents' vectors [" + docVector.length + "]."); + queryVector.size() + "] is different from the documents' vectors [" + vectorLength + "]."); } + Iterator queryVectorIter = queryVector.iterator(); + ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); + double l1norm = 0; - for (int dim = 0; dim < docVector.length; dim++){ - l1norm += Math.abs(queryVectorIter.next().floatValue() - docVector[dim]); + for (int dim = 0; dim < vectorLength; dim++) { + double queryValue = queryVectorIter.next().floatValue(); + double docValue = byteBuffer.getFloat(); + l1norm += Math.abs(queryValue - docValue); } return l1norm; } @@ -55,16 +62,19 @@ public L2Norm(ScoreScript scoreScript) { } public double l2norm(List queryVector, VectorScriptDocValues.DenseVectorScriptDocValues dvs){ - BytesRef value = dvs.getEncodedValue(); - float[] docVector = VectorEncoderDecoder.decodeDenseVector(scoreScript._getIndexVersion(), value); - if (queryVector.size() != docVector.length) { + BytesRef vector = dvs.getEncodedValue(); + int vectorLength = VectorEncoderDecoder.denseVectorLength(scoreScript._getIndexVersion(), vector); + if (queryVector.size() != vectorLength) { throw new IllegalArgumentException("Can't calculate l2norm! The number of dimensions of the query vector [" + - queryVector.size() + "] is different from the documents' vectors [" + docVector.length + "]."); + queryVector.size() + "] is different from the documents' vectors [" + vectorLength + "]."); } + + ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); Iterator queryVectorIter = queryVector.iterator(); + double l2norm = 0; - for (int dim = 0; dim < docVector.length; dim++){ - double diff = queryVectorIter.next().floatValue() - docVector[dim]; + for (int dim = 0; dim < vectorLength; dim++) { + double diff = queryVectorIter.next().floatValue() - byteBuffer.getFloat(); l2norm += diff * diff; } return Math.sqrt(l2norm); @@ -77,14 +87,23 @@ public static final class DotProduct { public DotProduct(ScoreScript scoreScript){ this.scoreScript = scoreScript; } + public double dotProduct(List queryVector, VectorScriptDocValues.DenseVectorScriptDocValues dvs){ - BytesRef value = dvs.getEncodedValue(); - float[] docVector = VectorEncoderDecoder.decodeDenseVector(scoreScript._getIndexVersion(), value); - if (queryVector.size() != docVector.length) { - throw new IllegalArgumentException("Can't calculate dotProduct! The number of dimensions of the query vector [" + - queryVector.size() + "] is different from the documents' vectors [" + docVector.length + "]."); + BytesRef vector = dvs.getEncodedValue(); + int vectorLength = VectorEncoderDecoder.denseVectorLength(scoreScript._getIndexVersion(), vector); + if (queryVector.size() != vectorLength) { + throw new IllegalArgumentException("Can't calculate dotPRoduct! The number of dimensions of the query vector [" + + queryVector.size() + "] is different from the documents' vectors [" + vectorLength + "]."); } - return intDotProduct(queryVector, docVector); + + Iterator queryVectorIter = queryVector.iterator(); + ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); + + double dotProduct = 0; + for (int dim = 0; dim < vectorLength; dim++) { + dotProduct += queryVectorIter.next().floatValue() * byteBuffer.getFloat(); + } + return dotProduct; } } @@ -108,28 +127,35 @@ public CosineSimilarity(ScoreScript scoreScript, List queryVector) { } public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { - BytesRef value = dvs.getEncodedValue(); - float[] docVector = VectorEncoderDecoder.decodeDenseVector(scoreScript._getIndexVersion(), value); - if (queryVector.size() != docVector.length) { + BytesRef vector = dvs.getEncodedValue(); + int vectorLength = VectorEncoderDecoder.denseVectorLength(scoreScript._getIndexVersion(), vector); + if (queryVector.size() != vectorLength) { throw new IllegalArgumentException("Can't calculate cosineSimilarity! The number of dimensions of the query vector [" + - queryVector.size() + "] is different from the documents' vectors [" + docVector.length + "]."); + queryVector.size() + "] is different from the documents' vectors [" + vectorLength + "]."); } - float docVectorMagnitude = VectorEncoderDecoder.getVectorMagnitude(scoreScript._getIndexVersion(), value, docVector); - double docQueryDotProduct = intDotProduct(queryVector, docVector); - return docQueryDotProduct / (docVectorMagnitude * queryVectorMagnitude); - } - } - private static double intDotProduct(List v1, float[] v2){ - double v1v2DotProduct = 0; - Iterator v1Iter = v1.iterator(); - for (int dim = 0; dim < v2.length; dim++) { - v1v2DotProduct += v1Iter.next().floatValue() * v2[dim]; + Iterator queryVectorIter = queryVector.iterator(); + ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); + + double dotProduct = 0.0; + double docVectorMagnitude = 0.0f; + if (scoreScript._getIndexVersion().onOrAfter(Version.V_7_4_0)) { + for (int dim = 0; dim < vectorLength; dim++) { + dotProduct += queryVectorIter.next().floatValue() * byteBuffer.getFloat(); + } + docVectorMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(scoreScript._getIndexVersion(), vector); + } else { + for (int dim = 0; dim < vectorLength; dim++) { + float docValue = byteBuffer.getFloat(); + dotProduct += queryVectorIter.next().floatValue() * docValue; + docVectorMagnitude += docValue * docValue; + } + docVectorMagnitude = (float) Math.sqrt(docVectorMagnitude); + } + return dotProduct / (docVectorMagnitude * queryVectorMagnitude); } - return v1v2DotProduct; } - //**************FUNCTIONS FOR SPARSE VECTORS // Functions are implemented as classes to accept a hidden parameter scoreScript that contains some index settings. // Also, constructors for some functions accept queryVector to calculate and cache queryVectorMagnitude only once @@ -273,8 +299,18 @@ public double cosineSimilaritySparse(VectorScriptDocValues.SparseVectorScriptDoc BytesRef value = dvs.getEncodedValue(); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), value); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), value); - float docVectorMagnitude = VectorEncoderDecoder.getVectorMagnitude(scoreScript._getIndexVersion(), value, docValues); + double docQueryDotProduct = intDotProductSparse(queryValues, queryDims, docValues, docDims); + double docVectorMagnitude = 0.0f; + if (scoreScript._getIndexVersion().onOrAfter(Version.V_7_4_0)) { + docVectorMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(scoreScript._getIndexVersion(), value); + } else { + for (float docValue : docValues) { + docVectorMagnitude += docValue * docValue; + } + docVectorMagnitude = (float) Math.sqrt(docVectorMagnitude); + } + return docQueryDotProduct / (docVectorMagnitude * queryVectorMagnitude); } } diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java index 7db3b1d1ddfac..56c5d4d743db7 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java @@ -30,6 +30,7 @@ import org.elasticsearch.xpack.vectors.Vectors; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.Collection; import static org.hamcrest.Matchers.containsString; @@ -95,8 +96,8 @@ public void testDefaults() throws Exception { assertThat(fields[0], instanceOf(BinaryDocValuesField.class)); // assert that after decoding the indexed value is equal to expected BytesRef vectorBR = fields[0].binaryValue(); - float[] decodedValues = VectorEncoderDecoder.decodeDenseVector(indexVersion, vectorBR); - float decodedMagnitude = VectorEncoderDecoder.getVectorMagnitude(indexVersion, vectorBR, decodedValues); + float[] decodedValues = decodeDenseVector(indexVersion, vectorBR); + float decodedMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(indexVersion, vectorBR); assertEquals(expectedMagnitude, decodedMagnitude, 0.001f); assertArrayEquals( "Decoded dense vector values is not equal to the indexed one.", @@ -133,7 +134,7 @@ public void testAddDocumentsToIndexBefore_V_7_4_0() throws Exception { assertThat(fields[0], instanceOf(BinaryDocValuesField.class)); // assert that after decoding the indexed value is equal to expected BytesRef vectorBR = fields[0].binaryValue(); - float[] decodedValues = VectorEncoderDecoder.decodeDenseVector(indexVersion, vectorBR); + float[] decodedValues = decodeDenseVector(indexVersion, vectorBR); assertArrayEquals( "Decoded dense vector values is not equal to the indexed one.", validVector, @@ -142,6 +143,17 @@ public void testAddDocumentsToIndexBefore_V_7_4_0() throws Exception { ); } + private static float[] decodeDenseVector(Version indexVersion, BytesRef encodedVector) { + int dimCount = VectorEncoderDecoder.denseVectorLength(indexVersion, encodedVector); + float[] vector = new float[dimCount]; + + ByteBuffer byteBuffer = ByteBuffer.wrap(encodedVector.bytes, encodedVector.offset, encodedVector.length); + for (int dim = 0; dim < dimCount; dim++) { + vector[dim] = byteBuffer.getFloat(); + } + return vector; + } + public void testDocumentsWithIncorrectDims() throws Exception { IndexService indexService = createIndex("test-index"); int dims = 3; diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapperTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapperTests.java index aece80588b8d4..8a985bbaf7933 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapperTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/SparseVectorFieldMapperTests.java @@ -113,7 +113,7 @@ public void testDefaults() throws Exception { decodedValues, 0.001f ); - float decodedMagnitude = VectorEncoderDecoder.getVectorMagnitude(indexVersion, vectorBR, decodedValues); + float decodedMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(indexVersion, vectorBR); assertEquals(expectedMagnitude, decodedMagnitude, 0.001f); } diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java index ba7de2ad74528..631dcb8dd0da0 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoderTests.java @@ -17,48 +17,6 @@ public class VectorEncoderDecoderTests extends ESTestCase { - public void testDenseVectorEncodingDecoding() { - Version indexVersion = Version.CURRENT; - int dimCount = randomIntBetween(0, DenseVectorFieldMapper.MAX_DIMS_COUNT); - float[] expectedValues = new float[dimCount]; - double dotProduct = 0f; - for (int i = 0; i < dimCount; i++) { - expectedValues[i] = randomFloat(); - dotProduct += expectedValues[i] * expectedValues[i]; - } - float expectedMagnitude = (float) Math.sqrt(dotProduct); - - // test that values that went through encoding and decoding are equal to their original - BytesRef encodedDenseVector = mockEncodeDenseVector(expectedValues); - float[] decodedValues = VectorEncoderDecoder.decodeDenseVector(indexVersion, encodedDenseVector); - float decodedMagnitude = VectorEncoderDecoder.getVectorMagnitude(indexVersion, encodedDenseVector, decodedValues); - assertEquals(expectedMagnitude, decodedMagnitude, 0.0f); - assertArrayEquals( - "Decoded dense vector values are not equal to their original.", - expectedValues, - decodedValues, - 0.001f - ); - } - - public void testDenseVectorEncodingDecodingBefore7_4() { - Version indexVersion = Version.V_7_3_0; - int dimCount = randomIntBetween(0, DenseVectorFieldMapper.MAX_DIMS_COUNT); - float[] expectedValues = new float[dimCount]; - for (int i = 0; i < dimCount; i++) { - expectedValues[i] = randomFloat(); - } - // test that values that went through encoding and decoding are equal to their original - BytesRef encodedDenseVector = mockEncodeDenseVectorBefore7_4(expectedValues); - float[] decodedValues = VectorEncoderDecoder.decodeDenseVector(indexVersion, encodedDenseVector); - assertArrayEquals( - "Decoded dense vector values are not equal to their original.", - expectedValues, - decodedValues, - 0.001f - ); - } - public void testSparseVectorEncodingDecoding() { Version indexVersion = Version.CURRENT; int dimCount = randomIntBetween(0, 100); @@ -85,7 +43,7 @@ public void testSparseVectorEncodingDecoding() { BytesRef encodedSparseVector = VectorEncoderDecoder.encodeSparseVector(indexVersion, expectedDims, expectedValues, dimCount); int[] decodedDims = VectorEncoderDecoder.decodeSparseVectorDims(indexVersion, encodedSparseVector); float[] decodedValues = VectorEncoderDecoder.decodeSparseVector(indexVersion, encodedSparseVector); - float decodedMagnitude = VectorEncoderDecoder.getVectorMagnitude(indexVersion, encodedSparseVector, decodedValues); + float decodedMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(indexVersion, encodedSparseVector); assertEquals(expectedMagnitude, decodedMagnitude, 0.0f); assertArrayEquals( "Decoded sparse vector dims are not equal to their original!", @@ -137,29 +95,24 @@ public void testSparseVectorEncodingDecodingBefore7_4() { } // imitates the code in DenseVectorFieldMapper::parse - public static BytesRef mockEncodeDenseVector(float[] values) { - byte[] bytes = new byte[VectorEncoderDecoder.INT_BYTES * values.length + VectorEncoderDecoder.INT_BYTES]; + public static BytesRef mockEncodeDenseVector(float[] values, Version indexVersion) { + byte[] bytes = indexVersion.onOrAfter(Version.V_7_4_0) + ? new byte[VectorEncoderDecoder.INT_BYTES * values.length + VectorEncoderDecoder.INT_BYTES] + : new byte[VectorEncoderDecoder.INT_BYTES * values.length]; double dotProduct = 0f; + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); for (float value : values) { byteBuffer.putFloat(value); dotProduct += value * value; } - // encode vector magnitude at the end - float vectorMagnitude = (float) Math.sqrt(dotProduct); - byteBuffer.putFloat(vectorMagnitude); - return new BytesRef(bytes); - } - // imitates the code in DenseVectorFieldMapper::parse before version 7.4 - public static BytesRef mockEncodeDenseVectorBefore7_4(float[] values) { - byte[] bytes = new byte[VectorEncoderDecoder.INT_BYTES * values.length]; - ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); - for (float value : values) { - byteBuffer.putFloat(value); + if (indexVersion.onOrAfter(Version.V_7_4_0)) { + // encode vector magnitude at the end + float vectorMagnitude = (float) Math.sqrt(dotProduct); + byteBuffer.putFloat(vectorMagnitude); } return new BytesRef(bytes); - } // generate unique random dims diff --git a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java index cc2a82a47971c..c34c92e248aa8 100644 --- a/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java +++ b/x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java @@ -9,15 +9,15 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.Version; import org.elasticsearch.script.ScoreScript; -import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.DotProduct; +import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder; import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.CosineSimilarity; -import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L1Norm; -import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2Norm; -import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.DotProductSparse; import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.CosineSimilaritySparse; +import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.DotProduct; +import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.DotProductSparse; +import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L1Norm; import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L1NormSparse; +import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2Norm; import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2NormSparse; import java.util.Arrays; @@ -26,21 +26,25 @@ import java.util.Map; import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoderTests.mockEncodeDenseVector; - import static org.hamcrest.Matchers.containsString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; - public class ScoreScriptUtilsTests extends ESTestCase { + public void testDenseVectorFunctions() { + testDenseVectorFunctions(Version.V_7_3_0); + testDenseVectorFunctions(Version.CURRENT); + } + + private void testDenseVectorFunctions(Version indexVersion) { float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; - BytesRef encodedDocVector = mockEncodeDenseVector(docVector); + BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion); VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); ScoreScript scoreScript = mock(ScoreScript.class); - when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT); + when(scoreScript._getIndexVersion()).thenReturn(indexVersion); List queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f); @@ -84,14 +88,19 @@ public void testDenseVectorFunctions() { } public void testSparseVectorFunctions() { + testSparseVectorFunctions(Version.V_7_3_0); + testSparseVectorFunctions(Version.CURRENT); + } + + private void testSparseVectorFunctions(Version indexVersion) { int[] docVectorDims = {2, 10, 50, 113, 4545}; float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector( - Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length); + indexVersion, docVectorDims, docVectorValues, docVectorDims.length); VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); when(dvs.getEncodedValue()).thenReturn(encodedDocVector); ScoreScript scoreScript = mock(ScoreScript.class); - when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT); + when(scoreScript._getIndexVersion()).thenReturn(indexVersion); Map queryVector = new HashMap() {{ put("2", 0.5);