Skip to content

Commit

Permalink
Combine vector decoding and function computation. (#46103)
Browse files Browse the repository at this point in the history
This commit updates the dense vector functions like `cosineSimilarity` to
decode the document vector and compute the result at the same time. Previously,
we would fully decode the vector into an array, then calculate the function.
  • Loading branch information
jtibshirani authored Aug 29, 2019
1 parent d6e9aa1 commit 871af21
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -162,47 +162,21 @@ public void swap(int i, int j) {
}.sort(0, n);
}

/**
* Decodes a BytesRef into an array of floats
* @param indexVersion - index Version
* @param vectorBR - dense vector encoded in BytesRef
*/
public static float[] decodeDenseVector(Version indexVersion, BytesRef vectorBR) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}

int dimCount = indexVersion.onOrAfter(Version.V_7_4_0) ? (vectorBR.length - INT_BYTES) / INT_BYTES : vectorBR.length/ INT_BYTES;
float[] vector = new float[dimCount];

ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
for (int dim = 0; dim < dimCount; dim++) {
vector[dim] = byteBuffer.getFloat();
}
return vector;
public static int denseVectorLength(Version indexVersion, BytesRef vectorBR) {
return indexVersion.onOrAfter(Version.V_7_4_0)
? (vectorBR.length - INT_BYTES) / INT_BYTES
: vectorBR.length / INT_BYTES;
}

/**
* Calculates vector magnitude either by
* decoding last 4 bytes of BytesRef into a vector magnitude or calculating it
* @param indexVersion - index Version
* @param vectorBR - vector encoded in BytesRef
* @param vector - float vector
* Decodes the last 4 bytes of the encoded vector, which contains the vector magnitude.
* NOTE: this function can only be called on vectors from an index version greater than or
* equal to 7.4.0, since vectors created prior to that do not store the magnitude.
*/
public static float getVectorMagnitude(Version indexVersion, BytesRef vectorBR, float[] vector) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
public static float decodeVectorMagnitude(Version indexVersion, BytesRef vectorBR) {
assert indexVersion.onOrAfter(Version.V_7_4_0);
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4);

if (indexVersion.onOrAfter(Version.V_7_4_0)) { // decode vector magnitude
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4);
} else { // calculate vector magnitude
double dotProduct = 0f;
for (int dim = 0; dim < vector.length; dim++) {
dotProduct += (double) vector[dim] * vector[dim];
}
return (float) Math.sqrt(dotProduct);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
package org.elasticsearch.xpack.vectors.query;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;

import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand All @@ -32,16 +34,21 @@ public L1Norm(ScoreScript scoreScript) {
}

public double l1norm(List<Number> queryVector, VectorScriptDocValues.DenseVectorScriptDocValues dvs){
BytesRef value = dvs.getEncodedValue();
float[] docVector = VectorEncoderDecoder.decodeDenseVector(scoreScript._getIndexVersion(), value);
if (queryVector.size() != docVector.length) {
BytesRef vector = dvs.getEncodedValue();
int vectorLength = VectorEncoderDecoder.denseVectorLength(scoreScript._getIndexVersion(), vector);
if (queryVector.size() != vectorLength) {
throw new IllegalArgumentException("Can't calculate l1norm! The number of dimensions of the query vector [" +
queryVector.size() + "] is different from the documents' vectors [" + docVector.length + "].");
queryVector.size() + "] is different from the documents' vectors [" + vectorLength + "].");
}

Iterator<Number> queryVectorIter = queryVector.iterator();
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);

double l1norm = 0;
for (int dim = 0; dim < docVector.length; dim++){
l1norm += Math.abs(queryVectorIter.next().floatValue() - docVector[dim]);
for (int dim = 0; dim < vectorLength; dim++) {
double queryValue = queryVectorIter.next().floatValue();
double docValue = byteBuffer.getFloat();
l1norm += Math.abs(queryValue - docValue);
}
return l1norm;
}
Expand All @@ -55,16 +62,19 @@ public L2Norm(ScoreScript scoreScript) {
}

public double l2norm(List<Number> queryVector, VectorScriptDocValues.DenseVectorScriptDocValues dvs){
BytesRef value = dvs.getEncodedValue();
float[] docVector = VectorEncoderDecoder.decodeDenseVector(scoreScript._getIndexVersion(), value);
if (queryVector.size() != docVector.length) {
BytesRef vector = dvs.getEncodedValue();
int vectorLength = VectorEncoderDecoder.denseVectorLength(scoreScript._getIndexVersion(), vector);
if (queryVector.size() != vectorLength) {
throw new IllegalArgumentException("Can't calculate l2norm! The number of dimensions of the query vector [" +
queryVector.size() + "] is different from the documents' vectors [" + docVector.length + "].");
queryVector.size() + "] is different from the documents' vectors [" + vectorLength + "].");
}

ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
Iterator<Number> queryVectorIter = queryVector.iterator();

double l2norm = 0;
for (int dim = 0; dim < docVector.length; dim++){
double diff = queryVectorIter.next().floatValue() - docVector[dim];
for (int dim = 0; dim < vectorLength; dim++) {
double diff = queryVectorIter.next().floatValue() - byteBuffer.getFloat();
l2norm += diff * diff;
}
return Math.sqrt(l2norm);
Expand All @@ -77,14 +87,23 @@ public static final class DotProduct {
public DotProduct(ScoreScript scoreScript){
this.scoreScript = scoreScript;
}

public double dotProduct(List<Number> queryVector, VectorScriptDocValues.DenseVectorScriptDocValues dvs){
BytesRef value = dvs.getEncodedValue();
float[] docVector = VectorEncoderDecoder.decodeDenseVector(scoreScript._getIndexVersion(), value);
if (queryVector.size() != docVector.length) {
throw new IllegalArgumentException("Can't calculate dotProduct! The number of dimensions of the query vector [" +
queryVector.size() + "] is different from the documents' vectors [" + docVector.length + "].");
BytesRef vector = dvs.getEncodedValue();
int vectorLength = VectorEncoderDecoder.denseVectorLength(scoreScript._getIndexVersion(), vector);
if (queryVector.size() != vectorLength) {
throw new IllegalArgumentException("Can't calculate dotPRoduct! The number of dimensions of the query vector [" +
queryVector.size() + "] is different from the documents' vectors [" + vectorLength + "].");
}
return intDotProduct(queryVector, docVector);

Iterator<Number> queryVectorIter = queryVector.iterator();
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);

double dotProduct = 0;
for (int dim = 0; dim < vectorLength; dim++) {
dotProduct += queryVectorIter.next().floatValue() * byteBuffer.getFloat();
}
return dotProduct;
}
}

Expand All @@ -108,28 +127,35 @@ public CosineSimilarity(ScoreScript scoreScript, List<Number> queryVector) {
}

public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues dvs) {
BytesRef value = dvs.getEncodedValue();
float[] docVector = VectorEncoderDecoder.decodeDenseVector(scoreScript._getIndexVersion(), value);
if (queryVector.size() != docVector.length) {
BytesRef vector = dvs.getEncodedValue();
int vectorLength = VectorEncoderDecoder.denseVectorLength(scoreScript._getIndexVersion(), vector);
if (queryVector.size() != vectorLength) {
throw new IllegalArgumentException("Can't calculate cosineSimilarity! The number of dimensions of the query vector [" +
queryVector.size() + "] is different from the documents' vectors [" + docVector.length + "].");
queryVector.size() + "] is different from the documents' vectors [" + vectorLength + "].");
}
float docVectorMagnitude = VectorEncoderDecoder.getVectorMagnitude(scoreScript._getIndexVersion(), value, docVector);
double docQueryDotProduct = intDotProduct(queryVector, docVector);
return docQueryDotProduct / (docVectorMagnitude * queryVectorMagnitude);
}
}

private static double intDotProduct(List<Number> v1, float[] v2){
double v1v2DotProduct = 0;
Iterator<Number> v1Iter = v1.iterator();
for (int dim = 0; dim < v2.length; dim++) {
v1v2DotProduct += v1Iter.next().floatValue() * v2[dim];
Iterator<Number> queryVectorIter = queryVector.iterator();
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);

double dotProduct = 0.0;
double docVectorMagnitude = 0.0f;
if (scoreScript._getIndexVersion().onOrAfter(Version.V_7_4_0)) {
for (int dim = 0; dim < vectorLength; dim++) {
dotProduct += queryVectorIter.next().floatValue() * byteBuffer.getFloat();
}
docVectorMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(scoreScript._getIndexVersion(), vector);
} else {
for (int dim = 0; dim < vectorLength; dim++) {
float docValue = byteBuffer.getFloat();
dotProduct += queryVectorIter.next().floatValue() * docValue;
docVectorMagnitude += docValue * docValue;
}
docVectorMagnitude = (float) Math.sqrt(docVectorMagnitude);
}
return dotProduct / (docVectorMagnitude * queryVectorMagnitude);
}
return v1v2DotProduct;
}


//**************FUNCTIONS FOR SPARSE VECTORS
// Functions are implemented as classes to accept a hidden parameter scoreScript that contains some index settings.
// Also, constructors for some functions accept queryVector to calculate and cache queryVectorMagnitude only once
Expand Down Expand Up @@ -273,8 +299,18 @@ public double cosineSimilaritySparse(VectorScriptDocValues.SparseVectorScriptDoc
BytesRef value = dvs.getEncodedValue();
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), value);
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), value);
float docVectorMagnitude = VectorEncoderDecoder.getVectorMagnitude(scoreScript._getIndexVersion(), value, docValues);

double docQueryDotProduct = intDotProductSparse(queryValues, queryDims, docValues, docDims);
double docVectorMagnitude = 0.0f;
if (scoreScript._getIndexVersion().onOrAfter(Version.V_7_4_0)) {
docVectorMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(scoreScript._getIndexVersion(), value);
} else {
for (float docValue : docValues) {
docVectorMagnitude += docValue * docValue;
}
docVectorMagnitude = (float) Math.sqrt(docVectorMagnitude);
}

return docQueryDotProduct / (docVectorMagnitude * queryVectorMagnitude);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.elasticsearch.xpack.vectors.Vectors;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collection;

import static org.hamcrest.Matchers.containsString;
Expand Down Expand Up @@ -95,8 +96,8 @@ public void testDefaults() throws Exception {
assertThat(fields[0], instanceOf(BinaryDocValuesField.class));
// assert that after decoding the indexed value is equal to expected
BytesRef vectorBR = fields[0].binaryValue();
float[] decodedValues = VectorEncoderDecoder.decodeDenseVector(indexVersion, vectorBR);
float decodedMagnitude = VectorEncoderDecoder.getVectorMagnitude(indexVersion, vectorBR, decodedValues);
float[] decodedValues = decodeDenseVector(indexVersion, vectorBR);
float decodedMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(indexVersion, vectorBR);
assertEquals(expectedMagnitude, decodedMagnitude, 0.001f);
assertArrayEquals(
"Decoded dense vector values is not equal to the indexed one.",
Expand Down Expand Up @@ -133,7 +134,7 @@ public void testAddDocumentsToIndexBefore_V_7_4_0() throws Exception {
assertThat(fields[0], instanceOf(BinaryDocValuesField.class));
// assert that after decoding the indexed value is equal to expected
BytesRef vectorBR = fields[0].binaryValue();
float[] decodedValues = VectorEncoderDecoder.decodeDenseVector(indexVersion, vectorBR);
float[] decodedValues = decodeDenseVector(indexVersion, vectorBR);
assertArrayEquals(
"Decoded dense vector values is not equal to the indexed one.",
validVector,
Expand All @@ -142,6 +143,17 @@ public void testAddDocumentsToIndexBefore_V_7_4_0() throws Exception {
);
}

private static float[] decodeDenseVector(Version indexVersion, BytesRef encodedVector) {
int dimCount = VectorEncoderDecoder.denseVectorLength(indexVersion, encodedVector);
float[] vector = new float[dimCount];

ByteBuffer byteBuffer = ByteBuffer.wrap(encodedVector.bytes, encodedVector.offset, encodedVector.length);
for (int dim = 0; dim < dimCount; dim++) {
vector[dim] = byteBuffer.getFloat();
}
return vector;
}

public void testDocumentsWithIncorrectDims() throws Exception {
IndexService indexService = createIndex("test-index");
int dims = 3;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public void testDefaults() throws Exception {
decodedValues,
0.001f
);
float decodedMagnitude = VectorEncoderDecoder.getVectorMagnitude(indexVersion, vectorBR, decodedValues);
float decodedMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(indexVersion, vectorBR);
assertEquals(expectedMagnitude, decodedMagnitude, 0.001f);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,6 @@

public class VectorEncoderDecoderTests extends ESTestCase {

public void testDenseVectorEncodingDecoding() {
Version indexVersion = Version.CURRENT;
int dimCount = randomIntBetween(0, DenseVectorFieldMapper.MAX_DIMS_COUNT);
float[] expectedValues = new float[dimCount];
double dotProduct = 0f;
for (int i = 0; i < dimCount; i++) {
expectedValues[i] = randomFloat();
dotProduct += expectedValues[i] * expectedValues[i];
}
float expectedMagnitude = (float) Math.sqrt(dotProduct);

// test that values that went through encoding and decoding are equal to their original
BytesRef encodedDenseVector = mockEncodeDenseVector(expectedValues);
float[] decodedValues = VectorEncoderDecoder.decodeDenseVector(indexVersion, encodedDenseVector);
float decodedMagnitude = VectorEncoderDecoder.getVectorMagnitude(indexVersion, encodedDenseVector, decodedValues);
assertEquals(expectedMagnitude, decodedMagnitude, 0.0f);
assertArrayEquals(
"Decoded dense vector values are not equal to their original.",
expectedValues,
decodedValues,
0.001f
);
}

public void testDenseVectorEncodingDecodingBefore7_4() {
Version indexVersion = Version.V_7_3_0;
int dimCount = randomIntBetween(0, DenseVectorFieldMapper.MAX_DIMS_COUNT);
float[] expectedValues = new float[dimCount];
for (int i = 0; i < dimCount; i++) {
expectedValues[i] = randomFloat();
}
// test that values that went through encoding and decoding are equal to their original
BytesRef encodedDenseVector = mockEncodeDenseVectorBefore7_4(expectedValues);
float[] decodedValues = VectorEncoderDecoder.decodeDenseVector(indexVersion, encodedDenseVector);
assertArrayEquals(
"Decoded dense vector values are not equal to their original.",
expectedValues,
decodedValues,
0.001f
);
}

public void testSparseVectorEncodingDecoding() {
Version indexVersion = Version.CURRENT;
int dimCount = randomIntBetween(0, 100);
Expand All @@ -85,7 +43,7 @@ public void testSparseVectorEncodingDecoding() {
BytesRef encodedSparseVector = VectorEncoderDecoder.encodeSparseVector(indexVersion, expectedDims, expectedValues, dimCount);
int[] decodedDims = VectorEncoderDecoder.decodeSparseVectorDims(indexVersion, encodedSparseVector);
float[] decodedValues = VectorEncoderDecoder.decodeSparseVector(indexVersion, encodedSparseVector);
float decodedMagnitude = VectorEncoderDecoder.getVectorMagnitude(indexVersion, encodedSparseVector, decodedValues);
float decodedMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(indexVersion, encodedSparseVector);
assertEquals(expectedMagnitude, decodedMagnitude, 0.0f);
assertArrayEquals(
"Decoded sparse vector dims are not equal to their original!",
Expand Down Expand Up @@ -137,29 +95,24 @@ public void testSparseVectorEncodingDecodingBefore7_4() {
}

// imitates the code in DenseVectorFieldMapper::parse
public static BytesRef mockEncodeDenseVector(float[] values) {
byte[] bytes = new byte[VectorEncoderDecoder.INT_BYTES * values.length + VectorEncoderDecoder.INT_BYTES];
public static BytesRef mockEncodeDenseVector(float[] values, Version indexVersion) {
byte[] bytes = indexVersion.onOrAfter(Version.V_7_4_0)
? new byte[VectorEncoderDecoder.INT_BYTES * values.length + VectorEncoderDecoder.INT_BYTES]
: new byte[VectorEncoderDecoder.INT_BYTES * values.length];
double dotProduct = 0f;

ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
for (float value : values) {
byteBuffer.putFloat(value);
dotProduct += value * value;
}
// encode vector magnitude at the end
float vectorMagnitude = (float) Math.sqrt(dotProduct);
byteBuffer.putFloat(vectorMagnitude);
return new BytesRef(bytes);
}

// imitates the code in DenseVectorFieldMapper::parse before version 7.4
public static BytesRef mockEncodeDenseVectorBefore7_4(float[] values) {
byte[] bytes = new byte[VectorEncoderDecoder.INT_BYTES * values.length];
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
for (float value : values) {
byteBuffer.putFloat(value);
if (indexVersion.onOrAfter(Version.V_7_4_0)) {
// encode vector magnitude at the end
float vectorMagnitude = (float) Math.sqrt(dotProduct);
byteBuffer.putFloat(vectorMagnitude);
}
return new BytesRef(bytes);

}

// generate unique random dims
Expand Down
Loading

0 comments on commit 871af21

Please sign in to comment.