Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Add support to build vector data structures greedily and perform exact search when there are no engine files #2201

Merged
merged 2 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Update Default Rescore Context based on Dimension [#2149](https://github.com/opensearch-project/k-NN/pull/2149)
* KNNIterators should support with and without filters [#2155](https://github.com/opensearch-project/k-NN/pull/2155)
* Adding Support to Enable/Disble Share level Rescoring and Update Oversampling Factor[#2172](https://github.com/opensearch-project/k-NN/pull/2172)
* Add support to build vector data structures greedily and perform exact search when there are no engine files [#1942](https://github.com/opensearch-project/k-NN/issues/1942)
### Bug Fixes
* Add DocValuesProducers for releasing memory when close index [#1946](https://github.com/opensearch-project/k-NN/pull/1946)
* KNN80DocValues should only be considered for BinaryDocValues fields [#2147](https://github.com/opensearch-project/k-NN/pull/2147)
Expand Down
20 changes: 20 additions & 0 deletions src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public class KNNSettings {
* Settings name
*/
public static final String KNN_SPACE_TYPE = "index.knn.space_type";
public static final String INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD = "index.knn.advanced.approximate_threshold";
public static final String KNN_ALGO_PARAM_M = "index.knn.algo_param.m";
public static final String KNN_ALGO_PARAM_EF_CONSTRUCTION = "index.knn.algo_param.ef_construction";
public static final String KNN_ALGO_PARAM_EF_SEARCH = "index.knn.algo_param.ef_search";
Expand Down Expand Up @@ -97,6 +98,9 @@ public class KNNSettings {
public static final boolean KNN_DEFAULT_FAISS_AVX2_DISABLED_VALUE = false;
public static final boolean KNN_DEFAULT_FAISS_AVX512_DISABLED_VALUE = false;
public static final String INDEX_KNN_DEFAULT_SPACE_TYPE = "l2";
public static final Integer INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_DEFAULT_VALUE = 0;
public static final Integer INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MIN = -1;
public static final Integer INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MAX = Integer.MAX_VALUE - 2;
public static final String INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY = "hamming";
public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_M = 16;
public static final Integer INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH = 100;
Expand Down Expand Up @@ -156,6 +160,21 @@ public class KNNSettings {
Setting.Property.Deprecated
);

/**
* build_vector_data_structure_threshold - This parameter determines when to build vector data structure for knn fields during indexing
* and merging. Setting -1 (min) will skip building graph, whereas on any other values, the graph will be built if
* number of live docs in segment is greater than this threshold. Since max number of documents in a segment can
* be Integer.MAX_VALUE - 1, this setting will allow threshold to be up to 1 less than max number of documents in a segment
*/
public static final Setting<Integer> INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_SETTING = Setting.intSetting(
INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD,
INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_DEFAULT_VALUE,
INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MIN,
INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD_MAX,
IndexScope,
Dynamic
);

/**
* M - the number of bi-directional links created for every new element during construction.
* Reasonable range for M is 2-100. Higher M work better on datasets with high intrinsic
Expand Down Expand Up @@ -486,6 +505,7 @@ private Setting<?> getSetting(String key) {
public List<Setting<?>> getSettings() {
List<Setting<?>> settings = Arrays.asList(
INDEX_KNN_SPACE_TYPE,
INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_SETTING,
INDEX_KNN_ALGO_PARAM_M_SETTING,
INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING,
INDEX_KNN_ALGO_PARAM_EF_SEARCH_SETTING,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.codec.KNN990Codec.NativeEngines990KnnVectorsFormat;
import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams;
import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams;
Expand Down Expand Up @@ -129,7 +131,23 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
}

private NativeEngines990KnnVectorsFormat nativeEngineVectorsFormat() {
return new NativeEngines990KnnVectorsFormat(new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()));
// mapperService is already checked for null or valid instance type at caller, hence we don't need
// addition isPresent check here.
int approximateThreshold = getApproximateThresholdValue();
return new NativeEngines990KnnVectorsFormat(
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()),
approximateThreshold
);
}

private int getApproximateThresholdValue() {
// This is private method and mapperService is already checked for null or valid instance type before this call
// at caller, hence we don't need additional isPresent check here.
final IndexSettings indexSettings = mapperService.get().getIndexSettings();
final Integer approximateThresholdValue = indexSettings.getValue(KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_SETTING);
return approximateThresholdValue != null
? approximateThresholdValue
: KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_DEFAULT_VALUE;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.opensearch.knn.index.KNNSettings;

import java.io.IOException;

Expand All @@ -30,15 +31,20 @@ public class NativeEngines990KnnVectorsFormat extends KnnVectorsFormat {
/** The format for storing, reading, merging vectors on disk */
private static FlatVectorsFormat flatVectorsFormat;
private static final String FORMAT_NAME = "NativeEngines990KnnVectorsFormat";
private static int approximateThreshold;

public NativeEngines990KnnVectorsFormat() {
super(FORMAT_NAME);
flatVectorsFormat = new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
this(new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer()));
}

public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsFormat) {
this(flatVectorsFormat, KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_DEFAULT_VALUE);
}

public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat lucene99FlatVectorsFormat) {
public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat flatVectorsFormat, int approximateThreshold) {
super(FORMAT_NAME);
flatVectorsFormat = lucene99FlatVectorsFormat;
NativeEngines990KnnVectorsFormat.flatVectorsFormat = flatVectorsFormat;
NativeEngines990KnnVectorsFormat.approximateThreshold = approximateThreshold;
}

/**
Expand All @@ -48,7 +54,7 @@ public NativeEngines990KnnVectorsFormat(final FlatVectorsFormat lucene99FlatVect
*/
@Override
public KnnVectorsWriter fieldsWriter(final SegmentWriteState state) throws IOException {
return new NativeEngines990KnnVectorsWriter(state, flatVectorsFormat.fieldsWriter(state));
return new NativeEngines990KnnVectorsWriter(state, flatVectorsFormat.fieldsWriter(state), approximateThreshold);
}

/**
Expand All @@ -63,6 +69,12 @@ public KnnVectorsReader fieldsReader(final SegmentReadState state) throws IOExce

@Override
public String toString() {
return "NativeEngines99KnnVectorsFormat(name=" + this.getClass().getSimpleName() + ", flatVectorsFormat=" + flatVectorsFormat + ")";
return "NativeEngines99KnnVectorsFormat(name="
+ this.getClass().getSimpleName()
+ ", flatVectorsFormat="
+ flatVectorsFormat
+ ", approximateThreshold="
+ approximateThreshold
+ ")";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,16 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter {
private KNN990QuantizationStateWriter quantizationStateWriter;
private final List<NativeEngineFieldVectorsWriter<?>> fields = new ArrayList<>();
private boolean finished;
private final Integer approximateThreshold;

public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) {
public NativeEngines990KnnVectorsWriter(
SegmentWriteState segmentWriteState,
FlatVectorsWriter flatVectorsWriter,
Integer approximateThreshold
) {
this.segmentWriteState = segmentWriteState;
this.flatVectorsWriter = flatVectorsWriter;
this.approximateThreshold = approximateThreshold;
}

/**
Expand Down Expand Up @@ -98,6 +104,17 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
field.getVectors()
);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
// Check only after quantization state writer finish writing its state, since it is required
// even if there are no graph files in segment, which will be later used by exact search
if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
log.info(
"Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during flush",
fieldInfo.name,
totalLiveDocs,
approximateThreshold
);
continue;
}
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();

Expand Down Expand Up @@ -127,6 +144,17 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState
}

final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);
// Check only after quantization state writer finish writing its state, since it is required
// even if there are no graph files in segment, which will be later used by exact search
if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
log.info(
"Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during merge",
fieldInfo.name,
totalLiveDocs,
approximateThreshold
);
return;
}
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();

Expand Down Expand Up @@ -257,4 +285,11 @@ private void initQuantizationStateWriterIfNecessary() throws IOException {
quantizationStateWriter.writeHeader(segmentWriteState);
}
}

private boolean shouldSkipBuildingVectorDataStructure(final long docCount) {
if (approximateThreshold < 0) {
return true;
}
return docCount < approximateThreshold;
}
}
55 changes: 49 additions & 6 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

package org.opensearch.knn.index.query;

import com.google.common.base.Predicates;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.NonNull;
import lombok.Value;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
Expand All @@ -21,6 +23,7 @@
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.NestedBinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator;
Expand All @@ -36,7 +39,9 @@

import java.io.IOException;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.function.Predicate;

@Log4j2
@AllArgsConstructor
Expand All @@ -55,11 +60,41 @@ public class ExactSearcher {
public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext)
throws IOException {
KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext);
if (exactSearcherContext.getKnnQuery().getRadius() != null) {
return doRadialSearch(leafReaderContext, exactSearcherContext, iterator);
}
if (exactSearcherContext.getMatchedDocs() != null
&& exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) {
return scoreAllDocs(iterator);
}
return searchTopK(iterator, exactSearcherContext.getK());
return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue());
}

/**
* Perform radial search by comparing scores with min score. Currently, FAISS from native engine supports radial search.
* Hence, we assume that Radius from knnQuery is always distance, and we convert it to score since we do exact search uses scores
* to filter out the documents that does not have given min score.
* @param leafReaderContext
* @param exactSearcherContext
* @param iterator {@link KNNIterator}
* @return Map of docId and score
* @throws IOException exception raised by iterator during traversal
*/
private Map<Integer, Float> doRadialSearch(
LeafReaderContext leafReaderContext,
ExactSearcherContext exactSearcherContext,
KNNIterator iterator
) throws IOException {
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());
final KNNEngine engine = FieldInfoExtractor.extractKNNEngine(fieldInfo);
if (KNNEngine.FAISS != engine) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support radial search", engine));
}
final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo);
final float minScore = spaceType.scoreTranslation(knnQuery.getRadius());
return filterDocsByMinScore(exactSearcherContext, iterator, minScore);
}

private Map<Integer, Float> scoreAllDocs(KNNIterator iterator) throws IOException {
Expand All @@ -71,15 +106,17 @@ private Map<Integer, Float> scoreAllDocs(KNNIterator iterator) throws IOExceptio
return docToScore;
}

private Map<Integer, Float> searchTopK(KNNIterator iterator, int k) throws IOException {
private Map<Integer, Float> searchTopCandidates(KNNIterator iterator, int limit, @NonNull Predicate<Float> filterScore)
throws IOException {
// Creating min heap and init with MAX DocID and Score as -INF.
final HitQueue queue = new HitQueue(k, true);
final HitQueue queue = new HitQueue(limit, true);
ScoreDoc topDoc = queue.top();
final Map<Integer, Float> docToScore = new HashMap<>();
int docId;
while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
if (iterator.score() > topDoc.score) {
topDoc.score = iterator.score();
final float currentScore = iterator.score();
if (filterScore.test(currentScore) && currentScore > topDoc.score) {
topDoc.score = currentScore;
topDoc.doc = docId;
// As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we
// have seen till now on top.
Expand All @@ -98,10 +135,16 @@ private Map<Integer, Float> searchTopK(KNNIterator iterator, int k) throws IOExc
final ScoreDoc doc = queue.pop();
docToScore.put(doc.doc, doc.score);
}

return docToScore;
}

private Map<Integer, Float> filterDocsByMinScore(ExactSearcherContext context, KNNIterator iterator, float minScore)
throws IOException {
int maxResultWindow = context.getKnnQuery().getContext().getMaxResultWindow();
Predicate<Float> scoreGreaterThanOrEqualToMinScore = score -> score >= minScore;
return searchTopCandidates(iterator, maxResultWindow, scoreGreaterThanOrEqualToMinScore);
}

private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException {
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final BitSet matchedDocs = exactSearcherContext.getMatchedDocs();
Expand Down
Loading
Loading