Skip to content

Commit

Permalink
Change abstraction to wrap around KNN query
Browse files Browse the repository at this point in the history
  • Loading branch information
dungba88 committed Nov 21, 2024
1 parent 8d88cab commit b67637a
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 149 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (topK.scoreDocs.length == 0) {
return new MatchNoDocsQuery();
}
return createRewrittenQuery(reader, topK);
return createRewrittenQuery(reader, topK.scoreDocs);
}

private TopDocs searchLeaf(
Expand Down Expand Up @@ -255,18 +255,18 @@ protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
return TopDocs.merge(k, perLeafResults);
}

private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
int len = topK.scoreDocs.length;
static Query createRewrittenQuery(IndexReader reader, ScoreDoc[] scoreDocs) {
int len = scoreDocs.length;

assert len > 0;
float maxScore = topK.scoreDocs[0].score;
float maxScore = scoreDocs[0].score;

Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
Arrays.sort(scoreDocs, Comparator.comparingInt(a -> a.doc));
int[] docs = new int[len];
float[] scores = new float[len];
for (int i = 0; i < len; i++) {
docs[i] = topK.scoreDocs[i].doc;
scores[i] = topK.scoreDocs[i].score;
docs[i] = scoreDocs[i].doc;
scores[i] = scoreDocs[i].score;
}
int[] segmentStarts = findSegmentStarts(reader.leaves(), docs);
return new DocAndScoreQuery(docs, scores, maxScore, segmentStarts, reader.getContext().id());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;

import static org.apache.lucene.search.AbstractKnnVectorQuery.createRewrittenQuery;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.VectorSimilarityFunction;

/**
 * A wrapper of {@link KnnFloatVectorQuery} which does full-precision reranking.
 *
 * <p>The wrapped query is executed first to collect candidate documents; every candidate is then
 * re-scored by comparing its stored float vector against the target with the field's configured
 * {@link VectorSimilarityFunction}, and only the {@code k} best-scoring candidates are kept.
 *
 * @lucene.experimental
 */
public class RerankKnnFloatVectorQuery extends Query {

  private final int k;
  private final float[] target;
  private final KnnFloatVectorQuery query;

  /**
   * Execute the KnnFloatVectorQuery and re-rank using full-precision vectors
   *
   * @param query the KNN query to execute as initial phase
   * @param target the target of the search
   * @param k the number of documents to find
   * @throws IllegalArgumentException if <code>k</code> is less than 1
   * @throws NullPointerException if <code>query</code> or <code>target</code> is null
   */
  public RerankKnnFloatVectorQuery(KnnFloatVectorQuery query, float[] target, int k) {
    if (k < 1) {
      throw new IllegalArgumentException("k must be at least 1, got: " + k);
    }
    this.query = Objects.requireNonNull(query, "query");
    this.target = Objects.requireNonNull(target, "target");
    this.k = k;
  }

  /**
   * Runs the wrapped KNN query, re-scores each matching document against the target vector and
   * rewrites to a query matching only the {@code k} best re-scored documents.
   *
   * @param indexSearcher the searcher used to rewrite and execute the wrapped query
   * @return a query over the re-ranked top documents, or a {@link MatchNoDocsQuery} when the
   *     wrapped query yields no candidate with a vector value
   * @throws IOException if reading the index fails
   */
  @Override
  public Query rewrite(IndexSearcher indexSearcher) throws IOException {
    IndexReader reader = indexSearcher.getIndexReader();
    Query rewritten = indexSearcher.rewrite(query);
    // Scores of the first phase are discarded; only the matching doc ids are needed.
    Weight weight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
    String field = query.getField();
    HitQueue queue = new HitQueue(k, false);
    for (var leaf : reader.leaves()) {
      Scorer scorer = weight.scorer(leaf);
      if (scorer == null) {
        continue;
      }
      FloatVectorValues floatVectorValues = leaf.reader().getFloatVectorValues(field);
      if (floatVectorValues == null) {
        continue;
      }
      FieldInfo fi = leaf.reader().getFieldInfos().fieldInfo(field);
      VectorSimilarityFunction comparer = fi.getVectorSimilarityFunction();
      DocIdSetIterator iterator = scorer.iterator();
      int docId;
      while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
        // NOTE(review): on Lucene versions where vectorValue(int) takes a vector ordinal rather
        // than a doc id, this lookup must go through the values' doc-to-ordinal mapping — confirm
        // against the FloatVectorValues API of the target branch.
        float[] vectorValue = floatVectorValues.vectorValue(docId);
        float score = comparer.compare(vectorValue, target);
        // The scorer yields segment-local ids, while createRewrittenQuery resolves ids against
        // all leaves of the top-level reader, so translate to a reader-wide id here.
        queue.insertWithOverflow(new ScoreDoc(leaf.docBase + docId, score));
      }
    }
    int size = queue.size();
    if (size == 0) {
      // createRewrittenQuery requires at least one hit.
      return new MatchNoDocsQuery();
    }
    // pop() returns the least competitive hit first, so fill from the back to obtain a
    // best-first array; createRewrittenQuery reads the maximum score from element 0.
    ScoreDoc[] scoreDocs = new ScoreDoc[size];
    for (int i = size - 1; i >= 0; i--) {
      scoreDocs[i] = queue.pop();
    }
    return createRewrittenQuery(reader, scoreDocs);
  }

  @Override
  public int hashCode() {
    int result = Arrays.hashCode(target);
    result = 31 * result + Objects.hash(query, k);
    return result;
  }

  /**
   * Two instances are equal when they wrap equal queries, use the same {@code k} and have
   * element-wise equal targets, keeping this method consistent with {@link #hashCode()}.
   */
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    RerankKnnFloatVectorQuery that = (RerankKnnFloatVectorQuery) o;
    return k == that.k
        && Objects.equals(query, that.query)
        && Arrays.equals(target, that.target);
  }

  /** Delegates visiting to the wrapped KNN query. */
  @Override
  public void visit(QueryVisitor visitor) {
    query.visit(visitor);
  }

  @Override
  public String toString(String field) {
    return getClass().getSimpleName() + ":" + query.toString(field) + "[" + k + "]";
  }
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@

org.apache.lucene.codecs.TestMinimalCodec$MinimalCodec
org.apache.lucene.codecs.TestMinimalCodec$MinimalCompoundCodec
org.apache.lucene.search.TestTwoPhaseKnnVectorQuery$QuantizedCodec
org.apache.lucene.search.TestRerankKnnFloatVectorQuery$QuantizedCodec
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import org.junit.Before;
import org.junit.Test;

public class TestTwoPhaseKnnVectorQuery extends LuceneTestCase {
public class TestRerankKnnFloatVectorQuery extends LuceneTestCase {

private static final String FIELD = "vector";
public static final VectorSimilarityFunction VECTOR_SIMILARITY_FUNCTION =
Expand Down Expand Up @@ -85,8 +85,9 @@ public void testTwoPhaseKnnVectorQuery() throws Exception {
int k = 10;
double oversample = 1.0;

TwoPhaseKnnVectorQuery query =
new TwoPhaseKnnVectorQuery(FIELD, targetVector, k, oversample, null);
KnnFloatVectorQuery knnQuery =
new KnnFloatVectorQuery(FIELD, targetVector, k + (int) (k * oversample));
RerankKnnFloatVectorQuery query = new RerankKnnFloatVectorQuery(knnQuery, targetVector, k);
TopDocs topDocs = searcher.search(query, k);

// Step 3: Verify that TopDocs scores match similarity with unquantized vectors
Expand Down

0 comments on commit b67637a

Please sign in to comment.