Skip to content

Commit

Permalink
Use float instead of double for query vectors. (#46004)
Browse the repository at this point in the history
Currently, when using script_score functions like cosineSimilarity, the query
vector is treated as an array of doubles. Since the stored document vectors use
floats, the least surprising behavior is for the query vectors to be float
arrays as well.

In addition to improving consistency, this change may help with some
optimizations we have been considering around vector dot product.
  • Loading branch information
jtibshirani authored Aug 28, 2019
1 parent 056c2bd commit 8d16c9b
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ setup:

- match: {hits.hits.1._id: "2"}
- gte: {hits.hits.1._score: 12.29}
- lte: {hits.hits.1._score: 12.30}
- lte: {hits.hits.1._score: 12.31}

- match: {hits.hits.2._id: "3"}
- gte: {hits.hits.2._score: 0.00}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ setup:

- match: {hits.hits.1._id: "2"}
- gte: {hits.hits.1._score: 12.29}
- lte: {hits.hits.1._score: 12.30}
- lte: {hits.hits.1._score: 12.31}

- match: {hits.hits.2._id: "3"}
- gte: {hits.hits.2._score: 0.00}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public void swap(int i, int j) {
* @param values - values for the sparse query vector
* @param n - number of dimensions
*/
public static void sortSparseDimsDoubleValues(int[] dims, double[] values, int n) {
public static void sortSparseDimsFloatValues(int[] dims, float[] values, int n) {
new InPlaceMergeSorter() {
@Override
public int compare(int i, int j) {
Expand All @@ -143,7 +143,7 @@ public void swap(int i, int j) {
dims[i] = dims[j];
dims[j] = tempDim;

double tempValue = values[j];
float tempValue = values[j];
values[j] = values[i];
values[i] = tempValue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import java.util.List;
import java.util.Map;

import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder.sortSparseDimsDoubleValues;
import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder.sortSparseDimsFloatValues;

public class ScoreScriptUtils {

Expand All @@ -37,7 +37,7 @@ public static double l1norm(List<Number> queryVector, VectorScriptDocValues.Dens
Iterator<Number> queryVectorIter = queryVector.iterator();
double l1norm = 0;
for (int dim = 0; dim < docVector.length; dim++){
l1norm += Math.abs(queryVectorIter.next().doubleValue() - docVector[dim]);
l1norm += Math.abs(queryVectorIter.next().floatValue() - docVector[dim]);
}
return l1norm;
}
Expand All @@ -59,7 +59,7 @@ public static double l2norm(List<Number> queryVector, VectorScriptDocValues.Dens
Iterator<Number> queryVectorIter = queryVector.iterator();
double l2norm = 0;
for (int dim = 0; dim < docVector.length; dim++){
double diff = queryVectorIter.next().doubleValue() - docVector[dim];
double diff = queryVectorIter.next().floatValue() - docVector[dim];
l2norm += diff * diff;
}
return Math.sqrt(l2norm);
Expand Down Expand Up @@ -97,11 +97,11 @@ public static final class CosineSimilarity {
// calculate queryVectorMagnitude once per query execution
public CosineSimilarity(List<Number> queryVector) {
this.queryVector = queryVector;
double doubleValue;

double dotProduct = 0;
for (Number value : queryVector) {
doubleValue = value.doubleValue();
dotProduct += doubleValue * doubleValue;
float floatValue = value.floatValue();
dotProduct += floatValue * floatValue;
}
this.queryVectorMagnitude = Math.sqrt(dotProduct);
}
Expand Down Expand Up @@ -130,7 +130,7 @@ private static double intDotProduct(List<Number> v1, float[] v2){
double v1v2DotProduct = 0;
Iterator<Number> v1Iter = v1.iterator();
for (int dim = 0; dim < v2.length; dim++) {
v1v2DotProduct += v1Iter.next().doubleValue() * v2[dim];
v1v2DotProduct += v1Iter.next().floatValue() * v2[dim];
}
return v1v2DotProduct;
}
Expand All @@ -139,15 +139,15 @@ private static double intDotProduct(List<Number> v1, float[] v2){
//**************FUNCTIONS FOR SPARSE VECTORS

public static class VectorSparseFunctions {
final double[] queryValues;
final float[] queryValues;
final int[] queryDims;

// prepare queryVector once per script execution
// queryVector represents a map of dimensions to values
public VectorSparseFunctions(Map<String, Number> queryVector) {
//break vector into two arrays dims and values
int n = queryVector.size();
queryValues = new double[n];
queryValues = new float[n];
queryDims = new int[n];
int i = 0;
for (Map.Entry<String, Number> dimValue : queryVector.entrySet()) {
Expand All @@ -156,11 +156,11 @@ public VectorSparseFunctions(Map<String, Number> queryVector) {
} catch (final NumberFormatException e) {
throw new IllegalArgumentException("Failed to parse a query vector dimension, it must be an integer!", e);
}
queryValues[i] = dimValue.getValue().doubleValue();
queryValues[i] = dimValue.getValue().floatValue();
i++;
}
// Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions
sortSparseDimsDoubleValues(queryDims, queryValues, n);
sortSparseDimsFloatValues(queryDims, queryValues, n);
}
}

Expand Down Expand Up @@ -317,7 +317,7 @@ public double cosineSimilaritySparse(VectorScriptDocValues.SparseVectorScriptDoc
}
}

private static double intDotProductSparse(double[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) {
private static double intDotProductSparse(float[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) {
double v1v2DotProduct = 0;
int v1Index = 0;
int v2Index = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ public void testDenseVectorFunctions() {
BytesRef encodedDocVector = mockEncodeDenseVector(docVector);
VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
List<Number> queryVector = Arrays.asList(0.5, 111.3, -13.0, 14.8, -156.0);
List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);

// test dotProduct
double result = dotProduct(queryVector, dvs);
assertEquals("dotProduct result is not equal to the expected value!", 65425.626, result, 0.001);
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);

// test cosineSimilarity
CosineSimilarity cosineSimilarity = new CosineSimilarity(queryVector);
Expand Down Expand Up @@ -91,7 +91,7 @@ public void testSparseVectorFunctions() {
// test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(queryVector);
double result = docProductSparse.dotProductSparse(dvs);
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.626, result, 0.001);
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);

// test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector);
Expand Down Expand Up @@ -128,7 +128,7 @@ public void testSparseVectorMissingDimensions1() {
// test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(queryVector);
double result = docProductSparse.dotProductSparse(dvs);
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.626, result, 0.001);
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);

// test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector);
Expand Down Expand Up @@ -165,7 +165,7 @@ public void testSparseVectorMissingDimensions2() {
// test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(queryVector);
double result = docProductSparse.dotProductSparse(dvs);
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.626, result, 0.001);
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);

// test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector);
Expand Down

0 comments on commit 8d16c9b

Please sign in to comment.