From ddeb38d6819601922fb9bbe9f49df31f4ea5c4a7 Mon Sep 17 00:00:00 2001 From: Dooyong Kim Date: Tue, 31 Dec 2024 23:11:50 -0800 Subject: [PATCH] Use Lucene HNSW searcher to search FAISS --- .../knn/index/query/KNNQueryFactory.java | 3 +- .../opensearch/knn/index/query/KNNWeight.java | 92 +++++++++++-- .../nativelib/NativeEngineKnnVectorQuery.java | 124 ++++++++++++++++++ .../partial_loading/DistanceMaxHeap.java | 2 +- .../partial_loading/KdyFaissHnswGraph.java | 59 +++++++++ 5 files changed, 269 insertions(+), 11 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/store/partial_loading/KdyFaissHnswGraph.java diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index dab2e08c8..f156ec6b5 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -95,7 +95,8 @@ public static Query create(CreateQueryRequest createQueryRequest) { .rescoreContext(rescoreContext) .build(); } - return createQueryRequest.getRescoreContext().isPresent() ? new NativeEngineKnnVectorQuery(knnQuery) : knnQuery; + // return createQueryRequest.getRescoreContext().isPresent() ? new NativeEngineKnnVectorQuery(knnQuery) : knnQuery; + return new NativeEngineKnnVectorQuery(knnQuery); } Integer requestEfSearch = null; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 485db29e8..27ac60000 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -13,10 +13,13 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.FilteredDocIdSetIterator; +import org.apache.lucene.util.hnsw.BlockingFloatHeap; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; +import org.apache.lucene.search.knn.MultiLeafKnnCollector; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.BitSet; +import org.apache.lucene.search.TopKnnCollector; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; @@ -80,7 +83,7 @@ private static ExactSearcher DEFAULT_EXACT_SEARCHER; private final QuantizationService quantizationService; - private KdyMaxScoreTracker maxScoreTracker; + private MultiLeafKnnCollector collector; public KNNWeight(KNNQuery query, float boost) { super(query); @@ -90,7 +93,7 @@ public KNNWeight(KNNQuery query, float boost) { this.filterWeight = null; this.exactSearcher = DEFAULT_EXACT_SEARCHER; this.quantizationService = QuantizationService.getInstance(); - this.maxScoreTracker = new KdyMaxScoreTracker(knnQuery.getK()); + this.collector = new MultiLeafKnnCollector(100, new BlockingFloatHeap(100), new TopKnnCollector(100, Integer.MAX_VALUE)); } public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { @@ -101,7 +104,7 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { this.filterWeight = filterWeight; this.exactSearcher = DEFAULT_EXACT_SEARCHER; this.quantizationService = QuantizationService.getInstance(); - this.maxScoreTracker = new KdyMaxScoreTracker(knnQuery.getK()); + this.collector = new MultiLeafKnnCollector(100, new BlockingFloatHeap(100), new TopKnnCollector(100, Integer.MAX_VALUE)); } public static void initialize(ModelDao modelDao) { @@ -232,9 +235,80 @@ private Map doExactSearch(final LeafReaderContext context, final return exactSearch(context, exactSearcherContextBuilder.build()); } + public NativeMemoryAllocation kdyGetNativeMemoryAllocation(final LeafReaderContext context) throws IOException { + final SegmentReader reader = Lucene.segmentReader(context.reader()); + + FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + KNNEngine knnEngine; + SpaceType spaceType; + VectorDataType vectorDataType; + + // Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's + // metadata. + String modelId = fieldInfo.getAttribute(MODEL_ID); + if (modelId != null) { + ModelMetadata modelMetadata = modelDao.getMetadata(modelId); + if (!ModelUtil.isModelCreated(modelMetadata)) { + throw new RuntimeException("Model \"" + modelId + "\" is not created."); + } + + knnEngine = modelMetadata.getKnnEngine(); + spaceType = modelMetadata.getSpaceType(); + vectorDataType = modelMetadata.getVectorDataType(); + } else { + String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.DEFAULT.getName()); + knnEngine = KNNEngine.getEngine(engineName); + String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); + spaceType = SpaceType.getSpace(spaceTypeName); + vectorDataType = + VectorDataType.get(fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue())); + } + + final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo = + SegmentLevelQuantizationInfo.build(reader, fieldInfo, knnQuery.getField()); + // TODO: Change type of vector once more quantization methods are supported + final byte[] quantizedVector = SegmentLevelQuantizationUtil.quantizeVector(knnQuery.getQueryVector(), segmentLevelQuantizationInfo); + + List engineFiles = KNNCodecUtil.getEngineFiles(knnEngine.getExtension(), knnQuery.getField(), reader.getSegmentInfo().info); + if (engineFiles.isEmpty()) { + log.debug("[KNN] No native engine files found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName()); + return null; + } + + final String vectorIndexFileName = engineFiles.get(0); + final String cacheKey = NativeMemoryCacheKeyHelper.constructCacheKey(vectorIndexFileName, reader.getSegmentInfo().info); + + final KNNQueryResult[] results; + KNNCounter.GRAPH_QUERY_REQUESTS.increment(); + + // We need to first get index allocation + NativeMemoryAllocation indexAllocation; + try { + indexAllocation = nativeMemoryCacheManager.get(new NativeMemoryEntryContext.IndexEntryContext(reader.directory(), + cacheKey, + NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), + getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName(), + // TODO: In the future, more vector data types will be supported with quantization + quantizedVector == null ? vectorDataType : VectorDataType.BINARY + ), + knnQuery.getIndexName(), + modelId + ), true); + } catch (ExecutionException e) { + GRAPH_QUERY_ERRORS.increment(); + throw new RuntimeException(e); + } + + return indexAllocation; + } + private Map doANNSearch( final LeafReaderContext context, final BitSet filterIdsBitSet, final int cardinality, final int k ) throws IOException { + if (true) { + throw new RuntimeException("+++++++++++++++++"); + } + final SegmentReader reader = Lucene.segmentReader(context.reader()); FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); @@ -335,19 +409,19 @@ private Map doANNSearch( // Java based partial loading if (KdyControl.DO_JAVA) { - final float minEligibleMaxDistance = maxScoreTracker.distanceMaxHeap.size() < k - ? Float.MAX_VALUE : maxScoreTracker.distanceMaxHeap.top().distance; - // System.out.println("********* minEligibleMaxDistance=" + minEligibleMaxDistance); - // System.out.println("********* maxScoreTracker.distanceMaxHeap.size() =" + maxScoreTracker.distanceMaxHeap.size()); + final float maxEligibleMaxDistance = -collector.minCompetitiveSimilarity(); + // System.out.println("********* maxEligibleMaxDistance=" + maxEligibleMaxDistance); + // System.out.println("********* maxScoreTracker.distanceMaxHeap.size()=" + // + maxScoreTracker.distanceMaxHeap.size()); results = kdySearch(indexAllocation.getPartialLoadingContext().kdyHNSW, indexAllocation.getPartialLoadingContext().indexInputThreadLocalGetter.getIndexInputWithBuffer().indexInput, knnQuery.getQueryVector(), k, - minEligibleMaxDistance + maxEligibleMaxDistance ); for (KNNQueryResult knnResult : results) { - maxScoreTracker.distanceMaxHeap.insertWithOverflow(0, knnResult.getScore()); + collector.collect(0, -knnResult.getScore()); } } else { // C++ based partial loading diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index a34a0f1ee..75d61b95a 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -9,23 +9,38 @@ import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.Weight; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.TopKnnCollectorManager; +import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.HnswGraphSearcher; +import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; +import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.opensearch.common.StopWatch; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.query.ExactSearcher; import org.opensearch.knn.index.query.KNNQuery; import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.index.query.ResultUtil; import org.opensearch.knn.index.query.rescore.RescoreContext; +import org.opensearch.knn.index.store.partial_loading.FlatL2DistanceComputer; +import org.opensearch.knn.index.store.partial_loading.KdyFaissHnswGraph; +import org.opensearch.knn.index.store.partial_loading.KdyHNSW; import java.io.IOException; import java.util.ArrayList; @@ -36,6 +51,7 @@ import java.util.Objects; import java.util.concurrent.Callable; + /** * {@link KNNQuery} executes approximate nearest neighbor search (ANN) on a segment level. * {@link NativeEngineKnnVectorQuery} executes approximate nearest neighbor search but gives @@ -50,8 +66,115 @@ public class NativeEngineKnnVectorQuery extends Query { private final KNNQuery knnQuery; + public Weight kdyPartialLoading(IndexSearcher indexSearcher, ScoreMode scoreMode, float boost) throws IOException { + System.out.println("kdyPartialLoading!!!!!!!!!!!!"); + + KnnCollectorManager knnCollectorManager = new TopKnnCollectorManager(100, indexSearcher); + IndexReader reader = indexSearcher.getIndexReader(); + final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, scoreMode, 1); + List leafReaderContexts = reader.leaves(); + List> tasks = new ArrayList<>(leafReaderContexts.size()); + for (LeafReaderContext context : leafReaderContexts) { + tasks.add(() -> kdySearchLeaf(knnQuery.getQueryVector(), context, null, knnCollectorManager, knnWeight)); + } + TaskExecutor taskExecutor = indexSearcher.getTaskExecutor(); + TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new); + + // Merge sort the results + TopDocs topK = TopDocs.merge(10, perLeafResults); + if (topK.scoreDocs.length == 0) { + return new MatchNoDocsQuery().createWeight(indexSearcher, scoreMode, boost); + } + return createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost); + } + + private TopDocs kdySearchLeaf( + float[] queryVector, + LeafReaderContext ctx, + Weight filterWeight, + KnnCollectorManager knnCollectorManager, + KNNWeight knnWeight) + throws IOException { + TopDocs results = kdyGetLeafResults(queryVector, ctx, filterWeight, knnCollectorManager, knnWeight); + if (ctx.docBase > 0) { + for (ScoreDoc scoreDoc : results.scoreDocs) { + scoreDoc.doc += ctx.docBase; + } + } + return results; + } + + private TopDocs kdyGetLeafResults( + float[] queryVector, + LeafReaderContext ctx, + Weight filterWeight, + KnnCollectorManager knnCollectorManager, + KNNWeight knnWeight) + throws IOException { + final LeafReader reader = ctx.reader(); + final Bits liveDocs = reader.getLiveDocs(); + + if (filterWeight == null) { + return kdyApproximateSearch(queryVector, ctx, liveDocs, Integer.MAX_VALUE, knnCollectorManager, knnWeight); + } + + throw new RuntimeException("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX filterWeight != null???"); + } + + private TopDocs kdyApproximateSearch( + float[] queryVector, + LeafReaderContext context, + Bits acceptDocs, + int visitedLimit, + KnnCollectorManager knnCollectorManager, + KNNWeight knnWeight) + throws IOException { + KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context); + + NativeMemoryAllocation nativeMemoryAllocation = knnWeight.kdyGetNativeMemoryAllocation(context); + if (nativeMemoryAllocation == null) { + return TopDocsCollector.EMPTY_TOPDOCS; + } + + KdyHNSW kdyHNSW = nativeMemoryAllocation.getPartialLoadingContext().kdyHNSW; + IndexInput indexInput = + nativeMemoryAllocation.getPartialLoadingContext().indexInputThreadLocalGetter.getIndexInputWithBuffer().indexInput; + IndexInput vectorIndexInput = indexInput.clone(); + + FlatL2DistanceComputer l2Computer = + new FlatL2DistanceComputer(queryVector, kdyHNSW.indexFlatL2.codes, kdyHNSW.indexFlatL2.oneVectorByteSize); + + RandomVectorScorer scorer = new RandomVectorScorer() { + @Override public float score(int node) throws IOException { + return l2Computer.compute(vectorIndexInput, node); + } + + @Override public int maxOrd() { + return Integer.MAX_VALUE; + } + }; + + final KnnCollector collector = + new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc); + final Bits acceptedOrds = null; + KdyFaissHnswGraph kdyFaissHnswGraph = new KdyFaissHnswGraph(kdyHNSW, indexInput.clone()); + HnswGraphSearcher.search(scorer, collector, kdyFaissHnswGraph, acceptedOrds); + TopDocs results = knnCollector.topDocs(); + return results; + } + @Override public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, float boost) throws IOException { + System.out.println("++++++++++++++++++++++++++++"); + return kdyPartialLoading(indexSearcher, scoreMode, boost); + + + + + + + + /* final IndexReader reader = indexSearcher.getIndexReader(); final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, scoreMode, 1); List leafReaderContexts = reader.leaves(); @@ -85,6 +208,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo return new MatchNoDocsQuery().createWeight(indexSearcher, scoreMode, boost); } return createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost); + */ } private List> doSearch( diff --git a/src/main/java/org/opensearch/knn/index/store/partial_loading/DistanceMaxHeap.java b/src/main/java/org/opensearch/knn/index/store/partial_loading/DistanceMaxHeap.java index 72372c02b..b67f53367 100644 --- a/src/main/java/org/opensearch/knn/index/store/partial_loading/DistanceMaxHeap.java +++ b/src/main/java/org/opensearch/knn/index/store/partial_loading/DistanceMaxHeap.java @@ -138,7 +138,7 @@ private void upHeap(int origPos) { private void downHeap(int i) { FaissHNSW.IdAndDistance node = heap[i]; // save top node - int j = i << 1; // find smaller child + int j = i << 1; // find bigger child int k = (i << 1) + 1; if (k <= this.k && greaterThan(heap[k].distance, heap[j].distance)) { j = k; diff --git a/src/main/java/org/opensearch/knn/index/store/partial_loading/KdyFaissHnswGraph.java b/src/main/java/org/opensearch/knn/index/store/partial_loading/KdyFaissHnswGraph.java new file mode 100644 index 000000000..59c84ccec --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/store/partial_loading/KdyFaissHnswGraph.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.store.partial_loading; + +import java.io.IOException; +import org.apache.lucene.util.hnsw.HnswGraph; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.hnsw.HnswGraph; + +public class KdyFaissHnswGraph extends HnswGraph { + public FaissHNSW hnsw; + private long begin; + private long end; + private long currOffset; + private IndexInput indexInput; + + public KdyFaissHnswGraph(KdyHNSW kdyHNSW, IndexInput indexInput) { + this.hnsw = kdyHNSW.hnswFlatIndex.hnsw; + this.indexInput = indexInput; + } + + @Override public void seek(int level, int target) throws IOException { + long o = hnsw.offsets[target]; + begin = o + hnsw.cumNumberNeighborPerLevel[level]; + end = o + hnsw.cumNumberNeighborPerLevel[level + 1]; + currOffset = begin; + } + + @Override public int size() { + return 1000_0000; + } + + @Override public int nextNeighbor() throws IOException { + if (currOffset < end) { + final int id = hnsw.neighbors.readInt(indexInput, currOffset); + ++currOffset; + if (id >= 0) { + return id; + } + } + + return Integer.MAX_VALUE; + } + + @Override public int numLevels() throws IOException { + return hnsw.maxLevel; + } + + @Override public int entryNode() throws IOException { + return hnsw.entryPoint; + } + + @Override public HnswGraph.NodesIterator getNodesOnLevel(int level) throws IOException { + throw new UnsupportedOperationException(); + } +}