Skip to content

Commit

Permalink
Use VectorScorer for exact vector scoring (elastic#109945)
Browse files Browse the repository at this point in the history
Lucene 9.11 introduced a new VectorScorer interface. We should utilize
this interface when scoring exact vectors. 

related to: elastic#109293
  • Loading branch information
benwtrent authored Jun 20, 2024
1 parent c709b78 commit 3faf4ce
Show file tree
Hide file tree
Showing 7 changed files with 611 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,6 @@
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.queries.function.FunctionQuery;
import org.apache.lucene.queries.function.valuesource.ByteKnnVectorFieldSource;
import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction;
import org.apache.lucene.queries.function.valuesource.ConstKnnByteVectorValueSource;
import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource;
import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource;
import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
Expand Down Expand Up @@ -67,6 +58,7 @@
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.support.CoreValuesSourceType;
import org.elasticsearch.search.vectors.DenseVectorQuery;
import org.elasticsearch.search.vectors.ESDiversifyingChildrenByteKnnVectorQuery;
import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery;
import org.elasticsearch.search.vectors.ESKnnByteVectorQuery;
Expand Down Expand Up @@ -1484,19 +1476,7 @@ private Query createExactKnnByteQuery(byte[] queryVector) {
float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
}
VectorSimilarityFunction vectorSimilarityFunction = similarity.vectorSimilarityFunction(indexVersionCreated, elementType);
return new BooleanQuery.Builder().add(new FieldExistsQuery(name()), BooleanClause.Occur.FILTER)
.add(
new FunctionQuery(
new ByteVectorSimilarityFunction(
vectorSimilarityFunction,
new ByteKnnVectorFieldSource(name()),
new ConstKnnByteVectorValueSource(queryVector)
)
),
BooleanClause.Occur.SHOULD
)
.build();
return new DenseVectorQuery.Bytes(queryVector, name());
}

private Query createExactKnnFloatQuery(float[] queryVector) {
Expand All @@ -1519,19 +1499,7 @@ && isNotUnitVector(squaredMagnitude)) {
}
}
}
VectorSimilarityFunction vectorSimilarityFunction = similarity.vectorSimilarityFunction(indexVersionCreated, elementType);
return new BooleanQuery.Builder().add(new FieldExistsQuery(name()), BooleanClause.Occur.FILTER)
.add(
new FunctionQuery(
new FloatVectorSimilarityFunction(
vectorSimilarityFunction,
new FloatKnnVectorFieldSource(name()),
new ConstKnnFloatValueSource(queryVector)
)
),
BooleanClause.Occur.SHOULD
)
.build();
return new DenseVectorQuery.Floats(queryVector, name());
}

Query createKnnQuery(float[] queryVector, int numCands, Query filter, Float similarityThreshold, BitSetProducer parentFilter) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.search.vectors;

import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.search.Weight;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

/**
 * Exact knn query. Will iterate and score all documents that have the provided dense vector field in the index.
 * <p>
 * Concrete variants exist for float vectors ({@link Floats}) and byte vectors ({@link Bytes}); each one obtains a
 * per-segment {@link VectorScorer} from the segment's vector values and exposes it as a regular Lucene {@link Scorer}.
 */
public abstract class DenseVectorQuery extends Query {

    /** Name of the vector field whose values are scored. */
    protected final String field;

    /**
     * @param field name of the vector field to score against
     */
    public DenseVectorQuery(String field) {
        this.field = field;
    }

    @Override
    public void visit(QueryVisitor queryVisitor) {
        queryVisitor.visitLeaf(this);
    }

    /**
     * Weight shared by the float and byte variants. Subclasses only supply the per-segment {@link VectorScorer};
     * explanation and scorer construction are handled here.
     */
    abstract static class DenseVectorWeight extends Weight {
        private final String field;
        private final float boost;

        protected DenseVectorWeight(DenseVectorQuery query, float boost) {
            super(query);
            this.field = query.field;
            this.boost = boost;
        }

        /**
         * Creates the vector scorer for one segment.
         *
         * @param leafReaderContext the segment to score
         * @return the scorer, or {@code null} when there is nothing to score in this segment
         * @throws IOException if reading the vector values fails
         */
        abstract VectorScorer vectorScorer(LeafReaderContext leafReaderContext) throws IOException;

        /**
         * Explains the score of a single document.
         *
         * @param leafReaderContext the segment containing the document
         * @param doc               segment-local document id
         * @return a match carrying {@code similarity * boost}, or a no-match explanation when the segment has no
         *         vector scorer or the document is not positioned on by the scorer's iterator
         */
        @Override
        public Explanation explain(LeafReaderContext leafReaderContext, int doc) throws IOException {
            VectorScorer vectorScorer = vectorScorer(leafReaderContext);
            if (vectorScorer == null) {
                return Explanation.noMatch("No vector values found for field: " + field);
            }
            DocIdSetIterator iterator = vectorScorer.iterator();
            iterator.advance(doc);
            if (iterator.docID() == doc) {
                // Compute the similarity once and reuse it, so the reported value and the description always
                // agree and the vector comparison is not performed twice.
                float score = vectorScorer.score();
                return Explanation.match(score * boost, "found vector with calculated similarity: " + score);
            }
            return Explanation.noMatch("Document not found in vector values for field: " + field);
        }

        /**
         * @return a scorer over the segment's vectors, or {@code null} when {@link #vectorScorer} returns
         *         {@code null} for this segment
         */
        @Override
        public Scorer scorer(LeafReaderContext leafReaderContext) throws IOException {
            VectorScorer vectorScorer = vectorScorer(leafReaderContext);
            if (vectorScorer == null) {
                return null;
            }
            return new DenseVectorScorer(this, vectorScorer);
        }

        @Override
        public boolean isCacheable(LeafReaderContext leafReaderContext) {
            return true;
        }
    }

    /** Exact query scoring a {@code float[]} query vector against a float vector field. */
    public static class Floats extends DenseVectorQuery {

        private final float[] query;

        /**
         * @param query the query vector; the array is stored as-is, not copied
         * @param field name of the vector field to score against
         */
        public Floats(float[] query, String field) {
            super(field);
            this.query = query;
        }

        /** @return the query vector held by this query (the same array instance passed to the constructor) */
        public float[] getQuery() {
            return query;
        }

        @Override
        public String toString(String field) {
            return "DenseVectorQuery.Floats";
        }

        @Override
        public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
            return new DenseVectorWeight(Floats.this, boost) {
                @Override
                VectorScorer vectorScorer(LeafReaderContext leafReaderContext) throws IOException {
                    FloatVectorValues vectorValues = leafReaderContext.reader().getFloatVectorValues(field);
                    if (vectorValues == null) {
                        // The reader returned no float vector values for this field in this segment.
                        return null;
                    }
                    return vectorValues.scorer(query);
                }
            };
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Floats floats = (Floats) o;
            return Objects.equals(field, floats.field) && Arrays.equals(query, floats.query);
        }

        @Override
        public int hashCode() {
            return Objects.hash(field, Arrays.hashCode(query));
        }
    }

    /** Exact query scoring a {@code byte[]} query vector against a byte vector field. */
    public static class Bytes extends DenseVectorQuery {

        private final byte[] query;

        /**
         * @param query the query vector; the array is stored as-is, not copied
         * @param field name of the vector field to score against
         */
        public Bytes(byte[] query, String field) {
            super(field);
            this.query = query;
        }

        @Override
        public String toString(String field) {
            return "DenseVectorQuery.Bytes";
        }

        @Override
        public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
            return new DenseVectorWeight(Bytes.this, boost) {
                @Override
                VectorScorer vectorScorer(LeafReaderContext leafReaderContext) throws IOException {
                    ByteVectorValues vectorValues = leafReaderContext.reader().getByteVectorValues(field);
                    if (vectorValues == null) {
                        // The reader returned no byte vector values for this field in this segment.
                        return null;
                    }
                    return vectorValues.scorer(query);
                }
            };
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Bytes bytes = (Bytes) o;
            return Objects.equals(field, bytes.field) && Arrays.equals(query, bytes.query);
        }

        @Override
        public int hashCode() {
            return Objects.hash(field, Arrays.hashCode(query));
        }
    }

    /**
     * Adapts a {@link VectorScorer} to the {@link Scorer} API, multiplying each similarity by the weight's boost.
     */
    static class DenseVectorScorer extends Scorer {

        private final VectorScorer vectorScorer;
        // Captured once so that iterator(), docID() and score() all observe the same iterator instance.
        private final DocIdSetIterator iterator;
        private final float boost;

        DenseVectorScorer(DenseVectorWeight weight, VectorScorer vectorScorer) {
            super(weight);
            this.vectorScorer = vectorScorer;
            this.iterator = vectorScorer.iterator();
            this.boost = weight.boost;
        }

        @Override
        public DocIdSetIterator iterator() {
            // Return the cached instance rather than asking the VectorScorer again, keeping this method
            // consistent with docID() and the assertion in score().
            return iterator;
        }

        @Override
        public float getMaxScore(int upTo) throws IOException {
            // TODO: can we optimize this at all?
            return Float.POSITIVE_INFINITY;
        }

        /** @return the vector similarity for the current document multiplied by the boost */
        @Override
        public float score() throws IOException {
            assert iterator.docID() != -1;
            return vectorScorer.score() * boost;
        }

        @Override
        public int docID() {
            return iterator.docID();
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@

package org.elasticsearch.index.mapper.vectors;

import org.apache.lucene.queries.function.FunctionQuery;
import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction;
import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
Expand All @@ -25,6 +20,7 @@
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity;
import org.elasticsearch.search.vectors.DenseVectorQuery;
import org.elasticsearch.search.vectors.VectorData;

import java.io.IOException;
Expand Down Expand Up @@ -218,16 +214,7 @@ public void testExactKnnQuery() {
queryVector[i] = randomFloat();
}
Query query = field.createExactKnnQuery(VectorData.fromFloats(queryVector));
assertTrue(query instanceof BooleanQuery);
BooleanQuery booleanQuery = (BooleanQuery) query;
boolean foundFunction = false;
for (BooleanClause clause : booleanQuery) {
if (clause.getQuery() instanceof FunctionQuery functionQuery) {
foundFunction = true;
assertTrue(functionQuery.getValueSource() instanceof FloatVectorSimilarityFunction);
}
}
assertTrue("Unable to find FloatVectorSimilarityFunction in created BooleanQuery", foundFunction);
assertTrue(query instanceof DenseVectorQuery.Floats);
}
{
DenseVectorFieldType field = new DenseVectorFieldType(
Expand All @@ -245,16 +232,7 @@ public void testExactKnnQuery() {
queryVector[i] = randomByte();
}
Query query = field.createExactKnnQuery(VectorData.fromBytes(queryVector));
assertTrue(query instanceof BooleanQuery);
BooleanQuery booleanQuery = (BooleanQuery) query;
boolean foundFunction = false;
for (BooleanClause clause : booleanQuery) {
if (clause.getQuery() instanceof FunctionQuery functionQuery) {
foundFunction = true;
assertTrue(functionQuery.getValueSource() instanceof ByteVectorSimilarityFunction);
}
}
assertTrue("Unable to find FloatVectorSimilarityFunction in created BooleanQuery", foundFunction);
assertTrue(query instanceof DenseVectorQuery.Bytes);
}
}

Expand Down
Loading

0 comments on commit 3faf4ce

Please sign in to comment.