Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use float instead of double for query vectors. #46004

Merged
merged 3 commits into from
Aug 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ setup:

- match: {hits.hits.1._id: "2"}
- gte: {hits.hits.1._score: 12.29}
- lte: {hits.hits.1._score: 12.30}
- lte: {hits.hits.1._score: 12.31}

- match: {hits.hits.2._id: "3"}
- gte: {hits.hits.2._score: 0.00}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ setup:

- match: {hits.hits.1._id: "2"}
- gte: {hits.hits.1._score: 12.29}
- lte: {hits.hits.1._score: 12.30}
- lte: {hits.hits.1._score: 12.31}

- match: {hits.hits.2._id: "3"}
- gte: {hits.hits.2._score: 0.00}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public void swap(int i, int j) {
* @param values - values for the sparse query vector
* @param n - number of dimensions
*/
public static void sortSparseDimsDoubleValues(int[] dims, double[] values, int n) {
public static void sortSparseDimsFloatValues(int[] dims, float[] values, int n) {
new InPlaceMergeSorter() {
@Override
public int compare(int i, int j) {
Expand All @@ -143,7 +143,7 @@ public void swap(int i, int j) {
dims[i] = dims[j];
dims[j] = tempDim;

double tempValue = values[j];
float tempValue = values[j];
values[j] = values[i];
values[i] = tempValue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import java.util.List;
import java.util.Map;

import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder.sortSparseDimsDoubleValues;
import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder.sortSparseDimsFloatValues;

public class ScoreScriptUtils {

Expand All @@ -37,7 +37,7 @@ public static double l1norm(List<Number> queryVector, VectorScriptDocValues.Dens
Iterator<Number> queryVectorIter = queryVector.iterator();
double l1norm = 0;
for (int dim = 0; dim < docVector.length; dim++){
l1norm += Math.abs(queryVectorIter.next().doubleValue() - docVector[dim]);
l1norm += Math.abs(queryVectorIter.next().floatValue() - docVector[dim]);
}
return l1norm;
}
Expand All @@ -59,7 +59,7 @@ public static double l2norm(List<Number> queryVector, VectorScriptDocValues.Dens
Iterator<Number> queryVectorIter = queryVector.iterator();
double l2norm = 0;
for (int dim = 0; dim < docVector.length; dim++){
double diff = queryVectorIter.next().doubleValue() - docVector[dim];
double diff = queryVectorIter.next().floatValue() - docVector[dim];
l2norm += diff * diff;
}
return Math.sqrt(l2norm);
Expand Down Expand Up @@ -97,11 +97,11 @@ public static final class CosineSimilarity {
// calculate queryVectorMagnitude once per query execution
public CosineSimilarity(List<Number> queryVector) {
this.queryVector = queryVector;
double doubleValue;

double dotProduct = 0;
for (Number value : queryVector) {
doubleValue = value.doubleValue();
dotProduct += doubleValue * doubleValue;
float floatValue = value.floatValue();
dotProduct += floatValue * floatValue;
}
this.queryVectorMagnitude = Math.sqrt(dotProduct);
}
Expand Down Expand Up @@ -130,7 +130,7 @@ private static double intDotProduct(List<Number> v1, float[] v2){
double v1v2DotProduct = 0;
Iterator<Number> v1Iter = v1.iterator();
for (int dim = 0; dim < v2.length; dim++) {
v1v2DotProduct += v1Iter.next().doubleValue() * v2[dim];
v1v2DotProduct += v1Iter.next().floatValue() * v2[dim];
}
return v1v2DotProduct;
}
Expand All @@ -139,15 +139,15 @@ private static double intDotProduct(List<Number> v1, float[] v2){
//**************FUNCTIONS FOR SPARSE VECTORS

public static class VectorSparseFunctions {
final double[] queryValues;
final float[] queryValues;
final int[] queryDims;

// prepare queryVector once per script execution
// queryVector represents a map of dimensions to values
public VectorSparseFunctions(Map<String, Number> queryVector) {
//break vector into two arrays dims and values
int n = queryVector.size();
queryValues = new double[n];
queryValues = new float[n];
queryDims = new int[n];
int i = 0;
for (Map.Entry<String, Number> dimValue : queryVector.entrySet()) {
Expand All @@ -156,11 +156,11 @@ public VectorSparseFunctions(Map<String, Number> queryVector) {
} catch (final NumberFormatException e) {
throw new IllegalArgumentException("Failed to parse a query vector dimension, it must be an integer!", e);
}
queryValues[i] = dimValue.getValue().doubleValue();
queryValues[i] = dimValue.getValue().floatValue();
i++;
}
// Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions
sortSparseDimsDoubleValues(queryDims, queryValues, n);
sortSparseDimsFloatValues(queryDims, queryValues, n);
}
}

Expand Down Expand Up @@ -317,7 +317,7 @@ public double cosineSimilaritySparse(VectorScriptDocValues.SparseVectorScriptDoc
}
}

private static double intDotProductSparse(double[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) {
private static double intDotProductSparse(float[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) {
double v1v2DotProduct = 0;
int v1Index = 0;
int v2Index = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ public void testDenseVectorFunctions() {
BytesRef encodedDocVector = mockEncodeDenseVector(docVector);
VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
List<Number> queryVector = Arrays.asList(0.5, 111.3, -13.0, 14.8, -156.0);
List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);

// test dotProduct
double result = dotProduct(queryVector, dvs);
assertEquals("dotProduct result is not equal to the expected value!", 65425.626, result, 0.001);
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);

// test cosineSimilarity
CosineSimilarity cosineSimilarity = new CosineSimilarity(queryVector);
Expand Down Expand Up @@ -91,7 +91,7 @@ public void testSparseVectorFunctions() {
// test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(queryVector);
double result = docProductSparse.dotProductSparse(dvs);
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.626, result, 0.001);
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);

// test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector);
Expand Down Expand Up @@ -128,7 +128,7 @@ public void testSparseVectorMissingDimensions1() {
// test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(queryVector);
double result = docProductSparse.dotProductSparse(dvs);
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.626, result, 0.001);
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);

// test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector);
Expand Down Expand Up @@ -165,7 +165,7 @@ public void testSparseVectorMissingDimensions2() {
// test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(queryVector);
double result = docProductSparse.dotProductSparse(dvs);
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.626, result, 0.001);
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);

// test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector);
Expand Down