Use Lucene HNSW searcher to search FAISS
Dooyong Kim committed Jan 1, 2025
1 parent f13e9f4 commit ddeb38d
Showing 5 changed files with 269 additions and 11 deletions.
@@ -95,7 +95,8 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.rescoreContext(rescoreContext)
.build();
}
return createQueryRequest.getRescoreContext().isPresent() ? new NativeEngineKnnVectorQuery(knnQuery) : knnQuery;
// return createQueryRequest.getRescoreContext().isPresent() ? new NativeEngineKnnVectorQuery(knnQuery) : knnQuery;
return new NativeEngineKnnVectorQuery(knnQuery);
}

Integer requestEfSearch = null;
92 changes: 83 additions & 9 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
@@ -13,10 +13,13 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.FilteredDocIdSetIterator;
import org.apache.lucene.util.hnsw.BlockingFloatHeap;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.knn.MultiLeafKnnCollector;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.search.TopKnnCollector;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
@@ -80,7 +83,7 @@

private static ExactSearcher DEFAULT_EXACT_SEARCHER;
private final QuantizationService quantizationService;
private KdyMaxScoreTracker maxScoreTracker;
private MultiLeafKnnCollector collector;

public KNNWeight(KNNQuery query, float boost) {
super(query);
@@ -90,7 +93,7 @@ public KNNWeight(KNNQuery query, float boost) {
this.filterWeight = null;
this.exactSearcher = DEFAULT_EXACT_SEARCHER;
this.quantizationService = QuantizationService.getInstance();
this.maxScoreTracker = new KdyMaxScoreTracker(knnQuery.getK());
this.collector = new MultiLeafKnnCollector(100, new BlockingFloatHeap(100), new TopKnnCollector(100, Integer.MAX_VALUE));
}

public KNNWeight(KNNQuery query, float boost, Weight filterWeight) {
@@ -101,7 +104,7 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) {
this.filterWeight = filterWeight;
this.exactSearcher = DEFAULT_EXACT_SEARCHER;
this.quantizationService = QuantizationService.getInstance();
this.maxScoreTracker = new KdyMaxScoreTracker(knnQuery.getK());
this.collector = new MultiLeafKnnCollector(100, new BlockingFloatHeap(100), new TopKnnCollector(100, Integer.MAX_VALUE));
}

public static void initialize(ModelDao modelDao) {
@@ -232,9 +235,80 @@ private Map<Integer, Float> doExactSearch(final LeafReaderContext context, final
return exactSearch(context, exactSearcherContextBuilder.build());
}

public NativeMemoryAllocation kdyGetNativeMemoryAllocation(final LeafReaderContext context) throws IOException {
final SegmentReader reader = Lucene.segmentReader(context.reader());

FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
KNNEngine knnEngine;
SpaceType spaceType;
VectorDataType vectorDataType;

// Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's
// metadata.
String modelId = fieldInfo.getAttribute(MODEL_ID);
if (modelId != null) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new RuntimeException("Model \"" + modelId + "\" is not created.");
}

knnEngine = modelMetadata.getKnnEngine();
spaceType = modelMetadata.getSpaceType();
vectorDataType = modelMetadata.getVectorDataType();
} else {
String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.DEFAULT.getName());
knnEngine = KNNEngine.getEngine(engineName);
String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue());
spaceType = SpaceType.getSpace(spaceTypeName);
vectorDataType =
VectorDataType.get(fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()));
}

final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo =
SegmentLevelQuantizationInfo.build(reader, fieldInfo, knnQuery.getField());
// TODO: Change type of vector once more quantization methods are supported
final byte[] quantizedVector = SegmentLevelQuantizationUtil.quantizeVector(knnQuery.getQueryVector(), segmentLevelQuantizationInfo);

List<String> engineFiles = KNNCodecUtil.getEngineFiles(knnEngine.getExtension(), knnQuery.getField(), reader.getSegmentInfo().info);
if (engineFiles.isEmpty()) {
log.debug("[KNN] No native engine files found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName());
return null;
}

final String vectorIndexFileName = engineFiles.get(0);
final String cacheKey = NativeMemoryCacheKeyHelper.constructCacheKey(vectorIndexFileName, reader.getSegmentInfo().info);

KNNCounter.GRAPH_QUERY_REQUESTS.increment();

// We need to first get the index allocation
NativeMemoryAllocation indexAllocation;
try {
indexAllocation = nativeMemoryCacheManager.get(new NativeMemoryEntryContext.IndexEntryContext(reader.directory(),
cacheKey,
NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(),
getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName(),
// TODO: In the future, more vector data types will be supported with quantization
quantizedVector == null ? vectorDataType : VectorDataType.BINARY
),
knnQuery.getIndexName(),
modelId
), true);
} catch (ExecutionException e) {
GRAPH_QUERY_ERRORS.increment();
throw new RuntimeException(e);
}

return indexAllocation;
}
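
// Usage sketch (hedged, names illustrative): this helper is consumed by the
// partial-loading path in NativeEngineKnnVectorQuery; the allocation exposes
// everything Lucene's HNSW searcher needs to walk the FAISS index without
// loading it fully into memory:
//
//   NativeMemoryAllocation alloc = knnWeight.kdyGetNativeMemoryAllocation(leafCtx);
//   if (alloc != null) {
//       KdyHNSW graph = alloc.getPartialLoadingContext().kdyHNSW;
//       IndexInput input = alloc.getPartialLoadingContext()
//           .indexInputThreadLocalGetter.getIndexInputWithBuffer().indexInput;
//   }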

private Map<Integer, Float> doANNSearch(
final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality, final int k
) throws IOException {
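// Dev guard: this legacy ANN path is expected to be unreachable while
// createWeight routes everything through kdyPartialLoading.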
if (true) {
throw new RuntimeException("+++++++++++++++++");
}

final SegmentReader reader = Lucene.segmentReader(context.reader());

FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
@@ -335,19 +409,19 @@ private Map<Integer, Float> doANNSearch(

// Java based partial loading
if (KdyControl.DO_JAVA) {
final float minEligibleMaxDistance = maxScoreTracker.distanceMaxHeap.size() < k
? Float.MAX_VALUE : maxScoreTracker.distanceMaxHeap.top().distance;
// System.out.println("********* minEligibleMaxDistance=" + minEligibleMaxDistance);
// System.out.println("********* maxScoreTracker.distanceMaxHeap.size() =" + maxScoreTracker.distanceMaxHeap.size());
final float maxEligibleMaxDistance = -collector.minCompetitiveSimilarity();
// System.out.println("********* maxEligibleMaxDistance=" + maxEligibleMaxDistance);
// System.out.println("********* maxScoreTracker.distanceMaxHeap.size()="
// + maxScoreTracker.distanceMaxHeap.size());

results = kdySearch(indexAllocation.getPartialLoadingContext().kdyHNSW,
indexAllocation.getPartialLoadingContext().indexInputThreadLocalGetter.getIndexInputWithBuffer().indexInput,
knnQuery.getQueryVector(),
k,
minEligibleMaxDistance
maxEligibleMaxDistance
);
for (KNNQueryResult knnResult : results) {
maxScoreTracker.distanceMaxHeap.insertWithOverflow(0, knnResult.getScore());
collector.collect(0, -knnResult.getScore());
}
} else {
// C++ based partial loading
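
A note on the sign convention in the Java partial-loading hunk above: MultiLeafKnnCollector works in similarity space, where larger is better, while the FAISS side works in L2 distance, where smaller is better, and the hunk bridges the two by negation. The collector's shared BlockingFloatHeap propagates the current global top-k floor across segments, so each later segment can prune against a tighter distance bound. A minimal sketch of the assumed convention:

    // Hedged sketch; similarity == -distance is assumed from the hunk above.
    float maxEligibleMaxDistance = -collector.minCompetitiveSimilarity(); // widest distance still globally competitive
    collector.collect(docId, -distance);                                  // distances reported back as similarities
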
@@ -9,23 +9,38 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.TopKnnCollectorManager;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.index.query.ExactSearcher;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNWeight;
import org.opensearch.knn.index.query.ResultUtil;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.index.store.partial_loading.FlatL2DistanceComputer;
import org.opensearch.knn.index.store.partial_loading.KdyFaissHnswGraph;
import org.opensearch.knn.index.store.partial_loading.KdyHNSW;

import java.io.IOException;
import java.util.ArrayList;
@@ -36,6 +51,7 @@
import java.util.Objects;
import java.util.concurrent.Callable;


/**
* {@link KNNQuery} executes approximate nearest neighbor search (ANN) on a segment level.
* {@link NativeEngineKnnVectorQuery} executes approximate nearest neighbor search but gives
@@ -50,8 +66,115 @@ public class NativeEngineKnnVectorQuery extends Query {

private final KNNQuery knnQuery;

public Weight kdyPartialLoading(IndexSearcher indexSearcher, ScoreMode scoreMode, float boost) throws IOException {
System.out.println("kdyPartialLoading!!!!!!!!!!!!");

KnnCollectorManager knnCollectorManager = new TopKnnCollectorManager(100, indexSearcher);
IndexReader reader = indexSearcher.getIndexReader();
final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, scoreMode, 1);
List<LeafReaderContext> leafReaderContexts = reader.leaves();
List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
for (LeafReaderContext context : leafReaderContexts) {
tasks.add(() -> kdySearchLeaf(knnQuery.getQueryVector(), context, null, knnCollectorManager, knnWeight));
}
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);

// Merge sort the results
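// Note: the per-leaf collectors above are created with k = 100, but only the
// global top 10 survive this merge.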
TopDocs topK = TopDocs.merge(10, perLeafResults);
if (topK.scoreDocs.length == 0) {
return new MatchNoDocsQuery().createWeight(indexSearcher, scoreMode, boost);
}
return createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost);
}

private TopDocs kdySearchLeaf(
float[] queryVector,
LeafReaderContext ctx,
Weight filterWeight,
KnnCollectorManager knnCollectorManager,
KNNWeight knnWeight)
throws IOException {
TopDocs results = kdyGetLeafResults(queryVector, ctx, filterWeight, knnCollectorManager, knnWeight);
if (ctx.docBase > 0) {
for (ScoreDoc scoreDoc : results.scoreDocs) {
scoreDoc.doc += ctx.docBase;
}
}
return results;
}

private TopDocs kdyGetLeafResults(
float[] queryVector,
LeafReaderContext ctx,
Weight filterWeight,
KnnCollectorManager knnCollectorManager,
KNNWeight knnWeight)
throws IOException {
final LeafReader reader = ctx.reader();
final Bits liveDocs = reader.getLiveDocs();

if (filterWeight == null) {
return kdyApproximateSearch(queryVector, ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager, knnWeight);
}

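// Dev guard: the filtered (filterWeight != null) path is not yet wired into partial loading.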
throw new RuntimeException("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX filterWeight != null???");
}

private TopDocs kdyApproximateSearch(
float[] queryVector,
LeafReaderContext context,
Bits acceptDocs,
int visitedLimit,
KnnCollectorManager knnCollectorManager,
KNNWeight knnWeight)
throws IOException {
KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context);

NativeMemoryAllocation nativeMemoryAllocation = knnWeight.kdyGetNativeMemoryAllocation(context);
if (nativeMemoryAllocation == null) {
return TopDocsCollector.EMPTY_TOPDOCS;
}

KdyHNSW kdyHNSW = nativeMemoryAllocation.getPartialLoadingContext().kdyHNSW;
IndexInput indexInput =
nativeMemoryAllocation.getPartialLoadingContext().indexInputThreadLocalGetter.getIndexInputWithBuffer().indexInput;
IndexInput vectorIndexInput = indexInput.clone();

FlatL2DistanceComputer l2Computer =
new FlatL2DistanceComputer(queryVector, kdyHNSW.indexFlatL2.codes, kdyHNSW.indexFlatL2.oneVectorByteSize);

RandomVectorScorer scorer = new RandomVectorScorer() {
@Override public float score(int node) throws IOException {
return l2Computer.compute(vectorIndexInput, node);
}

@Override public int maxOrd() {
return Integer.MAX_VALUE;
}
};

final KnnCollector collector =
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc);
final Bits acceptedOrds = null;
KdyFaissHnswGraph kdyFaissHnswGraph = new KdyFaissHnswGraph(kdyHNSW, indexInput.clone());
HnswGraphSearcher.search(scorer, collector, kdyFaissHnswGraph, acceptedOrds);
TopDocs results = knnCollector.topDocs();
return results;
}
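
// Hedged sketch of the contract relied on above: HnswGraphSearcher only needs the
// HnswGraph seek/nextNeighbor protocol plus a RandomVectorScorer, so a FAISS-serialized
// graph can back it. Per visited node, the traversal boils down to (simplified):
//
//   graph.seek(level, nodeOrd);
//   for (int nb = graph.nextNeighbor(); nb != DocIdSetIterator.NO_MORE_DOCS; nb = graph.nextNeighbor()) {
//       if (visited.getAndSet(nb)) continue;
//       knnCollector.collect(nb, scorer.score(nb)); // kept only while competitive
//   }
//
// KdyFaissHnswGraph.nextNeighbor() returns Integer.MAX_VALUE (== NO_MORE_DOCS) to end the loop.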

@Override
public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, float boost) throws IOException {
System.out.println("++++++++++++++++++++++++++++");
return kdyPartialLoading(indexSearcher, scoreMode, boost);

/*
final IndexReader reader = indexSearcher.getIndexReader();
final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, scoreMode, 1);
List<LeafReaderContext> leafReaderContexts = reader.leaves();
@@ -85,6 +208,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
return new MatchNoDocsQuery().createWeight(indexSearcher, scoreMode, boost);
}
return createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost);
*/
}

private List<Map<Integer, Float>> doSearch(
@@ -138,7 +138,7 @@ private void upHeap(int origPos) {

private void downHeap(int i) {
FaissHNSW.IdAndDistance node = heap[i]; // save top node
int j = i << 1; // find smaller child
int j = i << 1; // find bigger child
int k = (i << 1) + 1;
if (k <= this.k && greaterThan(heap[k].distance, heap[j].distance)) {
j = k;
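
The hunk above is cut off mid-method. For context, a hedged reconstruction of the full routine (names follow the fragment; the loop body is an assumption): the heap is a bounded, 1-indexed max-heap keyed on distance, so the worst surviving candidate sits at the top and downHeap sifts toward the bigger child, which is exactly what the comment fix above corrects:

    private void downHeap(int i) {
        FaissHNSW.IdAndDistance node = heap[i]; // save the displaced node
        while ((i << 1) <= this.k) {
            int j = i << 1; // left child
            if (j + 1 <= this.k && greaterThan(heap[j + 1].distance, heap[j].distance)) {
                j = j + 1; // right child is the bigger one
            }
            if (!greaterThan(heap[j].distance, node.distance)) {
                break; // both children <= node: heap property restored
            }
            heap[i] = heap[j]; // promote the bigger child
            i = j;
        }
        heap[i] = node;
    }
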
@@ -0,0 +1,59 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.store.partial_loading;

import java.io.IOException;

import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.hnsw.HnswGraph;

public class KdyFaissHnswGraph extends HnswGraph {
public FaissHNSW hnsw;
private long begin;
private long end;
private long currOffset;
private IndexInput indexInput;

public KdyFaissHnswGraph(KdyHNSW kdyHNSW, IndexInput indexInput) {
this.hnsw = kdyHNSW.hnswFlatIndex.hnsw;
this.indexInput = indexInput;
}

@Override public void seek(int level, int target) throws IOException {
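// Assumed FAISS layout: offsets[target] is where node `target`'s neighbor slots begin,
// and cumNumberNeighborPerLevel[l] .. cumNumberNeighborPerLevel[l + 1] bounds level l within them.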
long o = hnsw.offsets[target];
begin = o + hnsw.cumNumberNeighborPerLevel[level];
end = o + hnsw.cumNumberNeighborPerLevel[level + 1];
currOffset = begin;
}

@Override public int size() {
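// Placeholder upper bound: the adapter does not track the node count; Lucene's
// searcher appears to use size() mainly to size its visited set.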
return 1000_0000;
}

@Override public int nextNeighbor() throws IOException {
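// FAISS pads unused neighbor slots with -1, so the first negative id ends the list;
// Integer.MAX_VALUE (== NO_MORE_DOCS) tells the searcher this node is exhausted.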
if (currOffset < end) {
final int id = hnsw.neighbors.readInt(indexInput, currOffset);
++currOffset;
if (id >= 0) {
return id;
}
}

return Integer.MAX_VALUE;
}

@Override public int numLevels() throws IOException {
return hnsw.maxLevel;
}

@Override public int entryNode() throws IOException {
return hnsw.entryPoint;
}

@Override public HnswGraph.NodesIterator getNodesOnLevel(int level) throws IOException {
throw new UnsupportedOperationException();
}
}
