Skip to content

Commit

Permalink
LUCENE-10063: implement SimpleTextKnnVectorsReader.search
Browse files Browse the repository at this point in the history
  • Loading branch information
msokolov authored Aug 31, 2021
1 parent 6ade29c commit 9c7f0d4
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,13 @@
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.BufferedChecksumIndexInput;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.IOContext;
Expand Down Expand Up @@ -140,7 +145,33 @@ public VectorValues getVectorValues(String field) throws IOException {

/**
 * Returns the {@code k} best-scoring documents for {@code target} in {@code field} by
 * exhaustively scoring every accepted document that has a vector value.
 *
 * <p>Documents rejected by {@code acceptDocs} (for example deleted documents) are skipped and are
 * neither scored nor counted in the reported total hits.
 *
 * @param field name of the vector field to search
 * @param target query vector; its length must equal the field's vector dimension
 * @param k maximum number of hits to return
 * @param acceptDocs documents allowed to match, or {@code null} if all documents are allowed
 * @return the top hits, best score first, with a total hit count equal to the number of documents
 *     that were scored
 * @throws IllegalArgumentException if {@code target.length} differs from the field's dimension
 * @throws IOException if reading the vector values fails
 */
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
  VectorValues values = getVectorValues(field);
  if (target.length != values.dimension()) {
    throw new IllegalArgumentException(
        "incorrect dimension for field "
            + field
            + "; expected "
            + values.dimension()
            + " but target has "
            + target.length);
  }
  FieldInfo info = readState.fieldInfos.fieldInfo(field);
  VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
  HitQueue topK = new HitQueue(k, false);
  int numScored = 0;
  int doc;
  while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
    // Honor the caller's filter before reading the vector, so rejected documents cost nothing
    // and can never appear in the results.
    if (acceptDocs != null && acceptDocs.get(doc) == false) {
      continue;
    }
    float[] vector = values.vectorValue();
    float score = vectorSimilarity.compare(vector, target);
    if (vectorSimilarity.reversed) {
      // For "reversed" similarities a smaller raw value means a closer match; map it so that
      // closer vectors get larger scores.
      score = 1 / (score + 1);
    }
    topK.insertWithOverflow(new ScoreDoc(doc, score));
    numScored++;
  }
  // Fill the array from the back so that the final order is best score first.
  ScoreDoc[] topScoreDocs = new ScoreDoc[topK.size()];
  for (int i = topScoreDocs.length - 1; i >= 0; i--) {
    topScoreDocs[i] = topK.pop();
  }
  return new TopDocs(new TotalHits(numScored, TotalHits.Relation.EQUAL_TO), topScoreDocs);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,8 @@
import org.apache.lucene.index.Term;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.LuceneTestCase.SuppressCodecs;

/** TestKnnVectorQuery tests KnnVectorQuery. */
@SuppressCodecs("SimpleText") // The codec must support kNN searches
public class TestKnnVectorQuery extends LuceneTestCase {

public void testEquals() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.lucene.document.StringField;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.Bits;
Expand Down Expand Up @@ -856,6 +857,15 @@ public void testRandomWithUpdatesAndGraph() throws Exception {
}
}
}
// assert that searchNearestVectors returns the expected number of documents, in
// descending score order
int k = random().nextInt(numDoc / 2);
TopDocs results =
ctx.reader().searchNearestVectors(fieldName, randomVector(dimension), k, liveDocs);
assertEquals(k, results.scoreDocs.length);
for (int i = 0; i < k - 1; i++) {
assertTrue(results.scoreDocs[i].score >= results.scoreDocs[i + 1].score);
}
}
}
}
Expand Down

0 comments on commit 9c7f0d4

Please sign in to comment.