From df7627c6e580843beb2361f5c2ec3519efd52280 Mon Sep 17 00:00:00 2001 From: Tejas Shah Date: Tue, 6 Aug 2024 11:34:08 -0700 Subject: [PATCH 1/6] Introduces NativeEngineKNNQuery which executes ANN on rewrite (#1877) Signed-off-by: Tejas Shah --- CHANGELOG.md | 1 + .../common/featureflags/KNNFeatureFlags.java | 46 +++++ .../org/opensearch/knn/index/KNNSettings.java | 24 ++- .../knn/index/query/KNNQueryFactory.java | 9 +- .../opensearch/knn/index/query/KNNScorer.java | 7 + .../opensearch/knn/index/query/KNNWeight.java | 24 ++- .../query/nativelib/DocAndScoreQuery.java | 185 +++++++++++++++++ .../nativelib/NativeEngineKnnVectorQuery.java | 140 +++++++++++++ .../featureflags/KNNFeatureFlagsTests.java | 34 ++++ .../knn/index/query/KNNQueryBuilderTests.java | 14 ++ .../knn/index/query/KNNQueryFactoryTests.java | 34 ++++ .../knn/index/query/KNNWeightTests.java | 4 +- .../nativelib/DocAndScoreQueryTests.java | 99 +++++++++ .../NativeEngineKNNVectorQueryIT.java | 190 ++++++++++++++++++ .../NativeEngineKNNVectorQueryTests.java | 156 ++++++++++++++ 15 files changed, 952 insertions(+), 15 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java create mode 100644 src/main/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQuery.java create mode 100644 src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java create mode 100644 src/test/java/org/opensearch/knn/common/featureflags/KNNFeatureFlagsTests.java create mode 100644 src/test/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQueryTests.java create mode 100644 src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryIT.java create mode 100644 src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 0955338f3e..e851b8f369 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,3 +30,4 @@ The format is based on [Keep a 
Changelog](https://keepachangelog.com/en/1.0.0/), * Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920) * Refactor KNNVectorFieldType from KNNVectorFieldMapper to a separate class for better readability. [#1931](https://github.com/opensearch-project/k-NN/pull/1931) * Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925) +* Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877) \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java b/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java new file mode 100644 index 0000000000..21160fc2d0 --- /dev/null +++ b/src/main/java/org/opensearch/knn/common/featureflags/KNNFeatureFlags.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.knn.common.featureflags; + +import com.google.common.annotations.VisibleForTesting; +import lombok.experimental.UtilityClass; +import org.opensearch.common.settings.Setting; +import org.opensearch.knn.index.KNNSettings; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.opensearch.common.settings.Setting.Property.Dynamic; +import static org.opensearch.common.settings.Setting.Property.NodeScope; + +/** + * Class to manage KNN feature flags + */ +@UtilityClass +public class KNNFeatureFlags { + + // Feature flags + private static final String KNN_LAUNCH_QUERY_REWRITE_ENABLED = "knn.feature.query.rewrite.enabled"; + private static final boolean KNN_LAUNCH_QUERY_REWRITE_ENABLED_DEFAULT = true; + + @VisibleForTesting + public static final Setting KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING = Setting.boolSetting( + KNN_LAUNCH_QUERY_REWRITE_ENABLED, + 
KNN_LAUNCH_QUERY_REWRITE_ENABLED_DEFAULT, + NodeScope, + Dynamic + ); + + public static List> getFeatureFlags() { + return Stream.of(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING).collect(Collectors.toUnmodifiableList()); + } + + public static boolean isKnnQueryRewriteEnabled() { + return Boolean.parseBoolean(KNNSettings.state().getSettingValue(KNN_LAUNCH_QUERY_REWRITE_ENABLED).toString()); + } +} diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index d2c04b94ef..33c7ff410b 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -9,17 +9,17 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchParseException; -import org.opensearch.cluster.metadata.IndexMetadata; -import org.opensearch.core.action.ActionListener; import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest; import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.unit.ByteSizeUnit; import org.opensearch.core.common.unit.ByteSizeValue; -import org.opensearch.common.unit.TimeValue; import org.opensearch.index.IndexModule; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryCacheManagerDto; @@ -28,20 +28,22 @@ import org.opensearch.monitor.os.OsProbe; import java.security.InvalidParameterException; -import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import 
java.util.Map; import java.util.Objects; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; +import static java.util.stream.Collectors.toUnmodifiableMap; import static org.opensearch.common.settings.Setting.Property.Dynamic; import static org.opensearch.common.settings.Setting.Property.IndexScope; import static org.opensearch.common.settings.Setting.Property.NodeScope; -import static org.opensearch.core.common.unit.ByteSizeValue.parseBytesSizeValue; import static org.opensearch.common.unit.MemorySizeValue.parseBytesSizeValueOrHeapRatio; +import static org.opensearch.core.common.unit.ByteSizeValue.parseBytesSizeValue; +import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.getFeatureFlags; /** * This class defines @@ -289,6 +291,9 @@ public class KNNSettings { } }; + private final static Map> FEATURE_FLAGS = getFeatureFlags().stream() + .collect(toUnmodifiableMap(Setting::getKey, Function.identity())); + private ClusterService clusterService; private Client client; @@ -326,7 +331,7 @@ private void setSettingsUpdateConsumers() { ); NativeMemoryCacheManager.getInstance().rebuildCache(builder.build()); - }, new ArrayList<>(dynamicCacheSettings.values())); + }, Stream.concat(dynamicCacheSettings.values().stream(), FEATURE_FLAGS.values().stream()).collect(Collectors.toUnmodifiableList())); } /** @@ -346,6 +351,10 @@ private Setting getSetting(String key) { return dynamicCacheSettings.get(key); } + if (FEATURE_FLAGS.containsKey(key)) { + return FEATURE_FLAGS.get(key); + } + if (KNN_CIRCUIT_BREAKER_TRIGGERED.equals(key)) { return KNN_CIRCUIT_BREAKER_TRIGGERED_SETTING; } @@ -390,7 +399,8 @@ public List> getSettings() { KNN_FAISS_AVX2_DISABLED_SETTING, KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING ); - return Stream.concat(settings.stream(), dynamicCacheSettings.values().stream()).collect(Collectors.toList()); + return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), 
dynamicCacheSettings.values().stream())) + .collect(Collectors.toList()); } public static boolean isKNNPluginEnabled() { diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index ee9a12a41d..f3161b2dba 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -16,12 +16,14 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; import java.util.Locale; import java.util.Map; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.isKnnQueryRewriteEnabled; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; /** @@ -98,9 +100,10 @@ public static Query create(CreateQueryRequest createQueryRequest) { methodParameters ); + KNNQuery knnQuery = null; switch (vectorDataType) { case BINARY: - return KNNQuery.builder() + knnQuery = KNNQuery.builder() .field(fieldName) .byteQueryVector(byteVector) .indexName(indexName) @@ -110,8 +113,9 @@ public static Query create(CreateQueryRequest createQueryRequest) { .filterQuery(validatedFilterQuery) .vectorDataType(vectorDataType) .build(); + break; default: - return KNNQuery.builder() + knnQuery = KNNQuery.builder() .field(fieldName) .queryVector(vector) .indexName(indexName) @@ -122,6 +126,7 @@ public static Query create(CreateQueryRequest createQueryRequest) { .vectorDataType(vectorDataType) .build(); } + return isKnnQueryRewriteEnabled() ? 
new NativeEngineKnnVectorQuery(knnQuery) : knnQuery; } Integer requestEfSearch = null; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNScorer.java b/src/main/java/org/opensearch/knn/index/query/KNNScorer.java index 02dc86e807..99962d3074 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNScorer.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNScorer.java @@ -87,6 +87,13 @@ public float score() throws IOException { public int docID() { return docIdsIter.docID(); } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof Scorer)) return false; + return getWeight().equals(((Scorer) obj).getWeight()); + } }; + } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index f54d8328e3..f886525254 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -108,6 +108,22 @@ public Explanation explain(LeafReaderContext context, int doc) { @Override public Scorer scorer(LeafReaderContext context) throws IOException { + final Map docIdToScoreMap = searchLeaf(context); + if (docIdToScoreMap.isEmpty()) { + return KNNScorer.emptyScorer(this); + } + + return convertSearchResponseToScorer(docIdToScoreMap); + } + + /** + * Executes k nearest neighbor search for a segment to get the top K results + * This is made public purely to be able to be reused in {@link org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery} + * + * @param context LeafReaderContext + * @return A Map of docId to scores for top k results + */ + public Map searchLeaf(LeafReaderContext context) throws IOException { final BitSet filterBitSet = getFilteredDocsBitSet(context); int cardinality = filterBitSet.cardinality(); @@ -115,7 +131,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { // We should give this condition a deeper look that where it should be placed. 
For now I feel this is a good // place, if (filterWeight != null && cardinality == 0) { - return KNNScorer.emptyScorer(this); + return Collections.emptyMap(); } final Map docIdsToScoreMap = new HashMap<>(); @@ -129,7 +145,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { } else { Map annResults = doANNSearch(context, filterBitSet, cardinality); if (annResults == null) { - return null; + return Collections.emptyMap(); } if (canDoExactSearchAfterANNSearch(cardinality, annResults.size())) { log.debug( @@ -144,9 +160,9 @@ public Scorer scorer(LeafReaderContext context) throws IOException { docIdsToScoreMap.putAll(annResults); } if (docIdsToScoreMap.isEmpty()) { - return KNNScorer.emptyScorer(this); + return Collections.emptyMap(); } - return convertSearchResponseToScorer(docIdsToScoreMap); + return docIdsToScoreMap; } private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException { diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQuery.java new file mode 100644 index 0000000000..f1a91d8784 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQuery.java @@ -0,0 +1,185 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.nativelib; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * This is 
the same as {@link org.apache.lucene.search.AbstractKnnVectorQuery.DocAndScoreQuery} + */ +final class DocAndScoreQuery extends Query { + + private final int k; + private final int[] docs; + private final float[] scores; + private final int[] segmentStarts; + private final Object contextIdentity; + + DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) { + this.k = k; + this.docs = docs; + this.scores = scores; + this.segmentStarts = segmentStarts; + this.contextIdentity = contextIdentity; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) { + if (searcher.getIndexReader().getContext().id() != contextIdentity) { + throw new IllegalStateException("This DocAndScore query was created by a different reader"); + } + return new Weight(this) { + @Override + public Explanation explain(LeafReaderContext context, int doc) { + int found = Arrays.binarySearch(docs, doc + context.docBase); + if (found < 0) { + return Explanation.noMatch("not in top " + k); + } + return Explanation.match(scores[found] * boost, "within top " + k); + } + + @Override + public int count(LeafReaderContext context) { + return segmentStarts[context.ord + 1] - segmentStarts[context.ord]; + } + + @Override + public Scorer scorer(LeafReaderContext context) { + if (segmentStarts[context.ord] == segmentStarts[context.ord + 1]) { + return null; + } + return new Scorer(this) { + final int lower = segmentStarts[context.ord]; + final int upper = segmentStarts[context.ord + 1]; + int upTo = -1; + + @Override + public DocIdSetIterator iterator() { + return new DocIdSetIterator() { + @Override + public int docID() { + return docIdNoShadow(); + } + + @Override + public int nextDoc() { + if (upTo == -1) { + upTo = lower; + } else { + ++upTo; + } + return docIdNoShadow(); + } + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + + @Override + public long cost() { + return 
upper - lower; + } + }; + } + + @Override + public float getMaxScore(int docId) { + docId += context.docBase; + float maxScore = 0; + for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docId; idx++) { + maxScore = Math.max(maxScore, scores[idx]); + } + return maxScore * boost; + } + + @Override + public float score() { + return scores[upTo] * boost; + } + + @Override + public int advanceShallow(int docid) { + int start = Math.max(upTo, lower); + int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase); + if (docidIndex < 0) { + docidIndex = -1 - docidIndex; + } + if (docidIndex >= upper) { + return NO_MORE_DOCS; + } + return docs[docidIndex]; + } + + /** + * move the implementation of docID() into a differently-named method so we can call it + * from DocIDSetIterator.docID() even though this class is anonymous + * + * @return the current docid + */ + private int docIdNoShadow() { + if (upTo == -1) { + return -1; + } + if (upTo >= upper) { + return NO_MORE_DOCS; + } + return docs[upTo] - context.docBase; + } + + @Override + public int docID() { + return docIdNoShadow(); + } + }; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } + }; + } + + @Override + public String toString(String field) { + return "DocAndScore[" + k + "][docs:" + Arrays.toString(docs) + ", scores:" + Arrays.toString(scores) + "]"; + } + + @Override + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this); + } + + @Override + public boolean equals(Object obj) { + if (!sameClassAs(obj)) { + return false; + } + return contextIdentity == ((DocAndScoreQuery) obj).contextIdentity + && Arrays.equals(docs, ((DocAndScoreQuery) obj).docs) + && Arrays.equals(scores, ((DocAndScoreQuery) obj).scores); + } + + @Override + public int hashCode() { + return Objects.hash(classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores)); + } +} diff --git 
a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java new file mode 100644 index 0000000000..6b9a40a9cf --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -0,0 +1,140 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.nativelib; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.util.Bits; +import org.opensearch.knn.index.query.KNNQuery; +import org.opensearch.knn.index.query.KNNWeight; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.Callable; + +/** + * {@link KNNQuery} executes approximate nearest neighbor search (ANN) on a segment level. + * {@link NativeEngineKnnVectorQuery} executes approximate nearest neighbor search but gives + * us the control to combine the top k results in each leaf and post process the results just + * for k-NN query if required. This is done by overriding the rewrite method to execute ANN on each leaf. + * {@link KNNQuery} does not give the ability to post process segment results. 
+ */ +@Getter +@RequiredArgsConstructor +public class NativeEngineKnnVectorQuery extends Query { + + private final KNNQuery knnQuery; + + @Override + public Query rewrite(final IndexSearcher indexSearcher) throws IOException { + final IndexReader reader = indexSearcher.getIndexReader(); + final KNNWeight knnWeight = (KNNWeight) knnQuery.createWeight(indexSearcher, ScoreMode.COMPLETE, 1); + List leafReaderContexts = reader.leaves(); + + List> tasks = new ArrayList<>(leafReaderContexts.size()); + for (LeafReaderContext leafReaderContext : leafReaderContexts) { + tasks.add(() -> searchLeaf(leafReaderContext, knnWeight)); + } + TopDocs[] perLeafResults = indexSearcher.getTaskExecutor().invokeAll(tasks).toArray(TopDocs[]::new); + // TopDocs.merge requires perLeafResults to be sorted in descending order. + TopDocs topK = TopDocs.merge(knnQuery.getK(), perLeafResults); + if (topK.scoreDocs.length == 0) { + return new MatchNoDocsQuery(); + } + return createRewrittenQuery(reader, topK); + } + + private Query createRewrittenQuery(IndexReader reader, TopDocs topK) { + int len = topK.scoreDocs.length; + Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc)); + int[] docs = new int[len]; + float[] scores = new float[len]; + for (int i = 0; i < len; i++) { + docs[i] = topK.scoreDocs[i].doc; + scores[i] = topK.scoreDocs[i].score; + } + int[] segmentStarts = findSegmentStarts(reader, docs); + return new DocAndScoreQuery(knnQuery.getK(), docs, scores, segmentStarts, reader.getContext().id()); + } + + private static int[] findSegmentStarts(IndexReader reader, int[] docs) { + int[] starts = new int[reader.leaves().size() + 1]; + starts[starts.length - 1] = docs.length; + if (starts.length == 2) { + return starts; + } + int resultIndex = 0; + for (int i = 1; i < starts.length - 1; i++) { + int upper = reader.leaves().get(i).docBase; + resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper); + if (resultIndex < 0) { + resultIndex = -1 - resultIndex; + } + 
starts[i] = resultIndex; + } + return starts; + } + + private TopDocs searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight) throws IOException { + int totalHits = 0; + final Map leafDocScores = queryWeight.searchLeaf(ctx); + final List scoreDocs = new ArrayList<>(); + final Bits liveDocs = ctx.reader().getLiveDocs(); + + if (!leafDocScores.isEmpty()) { + final List> topScores = new ArrayList<>(leafDocScores.entrySet()); + topScores.sort(Map.Entry.comparingByValue().reversed()); + + for (Map.Entry entry : topScores) { + if (liveDocs == null || liveDocs.get(entry.getKey())) { + ScoreDoc scoreDoc = new ScoreDoc(entry.getKey() + ctx.docBase, entry.getValue()); + scoreDocs.add(scoreDoc); + totalHits++; + } + } + } + + return new TopDocs(new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), scoreDocs.toArray(ScoreDoc[]::new)); + } + + @Override + public String toString(String field) { + return this.getClass().getSimpleName() + "[" + field + "]..." + KNNQuery.class.getSimpleName() + "[" + knnQuery.toString() + "]"; + } + + @Override + public void visit(QueryVisitor visitor) { + visitor.visitLeaf(this); + } + + @Override + public boolean equals(Object obj) { + if (!sameClassAs(obj)) { + return false; + } + return knnQuery == ((NativeEngineKnnVectorQuery) obj).knnQuery; + } + + @Override + public int hashCode() { + return Objects.hash(classHash(), knnQuery.hashCode()); + } +} diff --git a/src/test/java/org/opensearch/knn/common/featureflags/KNNFeatureFlagsTests.java b/src/test/java/org/opensearch/knn/common/featureflags/KNNFeatureFlagsTests.java new file mode 100644 index 0000000000..c3a8a1615d --- /dev/null +++ b/src/test/java/org/opensearch/knn/common/featureflags/KNNFeatureFlagsTests.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.common.featureflags; + +import org.mockito.Mock; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.knn.KNNTestCase; 
+import org.opensearch.knn.index.KNNSettings; + +import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING; +import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.isKnnQueryRewriteEnabled; + +public class KNNFeatureFlagsTests extends KNNTestCase { + + @Mock + ClusterSettings clusterSettings; + + public void setUp() throws Exception { + super.setUp(); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + KNNSettings.state().setClusterService(clusterService); + } + + public void testIsFeatureEnabled() throws Exception { + when(clusterSettings.get(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING)).thenReturn(false); + assertFalse(isKnnQueryRewriteEnabled()); + when(clusterSettings.get(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING)).thenReturn(true); + assertTrue(isKnnQueryRewriteEnabled()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 0241a9afbe..0b918bd9ed 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -11,10 +11,12 @@ import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; +import org.junit.Before; import org.opensearch.Version; import org.opensearch.cluster.ClusterModule; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.common.io.stream.StreamInput; @@ -31,6 +33,7 @@ import org.opensearch.knn.index.util.KNNClusterUtil; import 
org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; @@ -51,6 +54,7 @@ import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING; import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; @@ -67,6 +71,16 @@ public class KNNQueryBuilderTests extends KNNTestCase { protected static final String TEXT_FIELD_NAME = "some_field"; protected static final String TEXT_VALUE = "some_value"; + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + ClusterSettings clusterSettings = mock(ClusterSettings.class); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterSettings.get(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING)).thenReturn(false); + KNNSettings.state().setClusterService(clusterService); + } + public void testInvalidK() { float[] queryVector = { 1.0f, 1.0f }; diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index c74a79946a..7bacc7d10c 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -14,8 +14,11 @@ import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.apache.lucene.search.join.ToChildBlockJoinQuery; +import org.junit.Before; +import org.mockito.Mock; import 
org.mockito.MockedConstruction; import org.mockito.Mockito; +import org.opensearch.common.settings.ClusterSettings; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.query.QueryBuilder; @@ -23,8 +26,10 @@ import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.index.search.NestedHelper; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; import java.util.Arrays; import java.util.List; @@ -36,6 +41,7 @@ import static org.mockito.Mockito.when; import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; +import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING; public class KNNQueryFactoryTests extends KNNTestCase { private static final String FILTER_FILED_NAME = "foo"; @@ -50,8 +56,21 @@ public class KNNQueryFactoryTests extends KNNTestCase { private final int testK = 10; private final Map methodParameters = Map.of(METHOD_PARAMETER_EF_SEARCH, 100); + @Mock + ClusterSettings clusterSettings; + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterSettings.get(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING)).thenReturn(false); + KNNSettings.state().setClusterService(clusterService); + } + public void testCreateCustomKNNQuery() { for (KNNEngine knnEngine : KNNEngine.getEnginesThatCreateCustomSegmentFiles()) { + when(clusterSettings.get(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING)).thenReturn(false); Query query = KNNQueryFactory.create( knnEngine, testIndexName, @@ -61,6 +80,15 @@ public void 
testCreateCustomKNNQuery() { DEFAULT_VECTOR_DATA_TYPE_FIELD ); assertTrue(query instanceof KNNQuery); + assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); + assertEquals(testFieldName, ((KNNQuery) query).getField()); + assertEquals(testQueryVector, ((KNNQuery) query).getQueryVector()); + assertEquals(testK, ((KNNQuery) query).getK()); + + when(clusterSettings.get(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING)).thenReturn(true); + query = KNNQueryFactory.create(knnEngine, testIndexName, testFieldName, testQueryVector, testK, DEFAULT_VECTOR_DATA_TYPE_FIELD); + assertTrue(query instanceof NativeEngineKnnVectorQuery); + query = ((NativeEngineKnnVectorQuery) query).getKnnQuery(); assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); assertEquals(testFieldName, ((KNNQuery) query).getField()); @@ -392,6 +420,7 @@ public void testCreate_whenBinary_thenSuccess() { when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); BitSetProducer parentFilter = mock(BitSetProducer.class); when(mockQueryShardContext.getParentFilter()).thenReturn(parentFilter); + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() .knnEngine(KNNEngine.FAISS) .indexName(testIndexName) @@ -407,5 +436,10 @@ public void testCreate_whenBinary_thenSuccess() { assertTrue(query instanceof KNNQuery); assertNotNull(((KNNQuery) query).getByteQueryVector()); assertNull(((KNNQuery) query).getQueryVector()); + + when(clusterSettings.get(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING)).thenReturn(true); + query = KNNQueryFactory.create(createQueryRequest); + assertTrue(query instanceof NativeEngineKnnVectorQuery); } + } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index c7077eace1..c5abc964db 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ 
-350,7 +350,7 @@ public void testShardWithoutFiles() { when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); final Scorer knnScorer = knnWeight.scorer(leafReaderContext); - assertNull(knnScorer); + assertEquals(KNNScorer.emptyScorer(knnWeight), knnScorer); } @SneakyThrows @@ -394,7 +394,7 @@ public void testEmptyQueryResults() { when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); final Scorer knnScorer = knnWeight.scorer(leafReaderContext); - assertNull(knnScorer); + assertEquals(KNNScorer.emptyScorer(knnWeight), knnScorer); } @SneakyThrows diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQueryTests.java new file mode 100644 index 0000000000..185cb5d471 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/DocAndScoreQueryTests.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.nativelib; + +import lombok.SneakyThrows; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexReaderContext; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.mockito.Mock; +import org.opensearch.test.OpenSearchTestCase; + +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.openMocks; + +public class DocAndScoreQueryTests extends OpenSearchTestCase { + + @Mock + private LeafReaderContext leaf1; + @Mock + private IndexSearcher indexSearcher; + @Mock + private IndexReader reader; + @Mock + private IndexReaderContext readerContext; + + private DocAndScoreQuery objectUnderTest; + + @Override + public void setUp() 
throws Exception { + super.setUp(); + openMocks(this); + + when(indexSearcher.getIndexReader()).thenReturn(reader); + when(reader.getContext()).thenReturn(readerContext); + when(readerContext.id()).thenReturn(1); + } + + // Note: cannot test with multi leaf as there LeafReaderContext is readonly with no getters for some fields to mock + public void testScorer() throws Exception { + // Given + int[] expectedDocs = { 0, 1, 2, 3, 4 }; + float[] expectedScores = { 0.1f, 1.2f, 2.3f, 5.1f, 3.4f }; + int[] findSegments = { 0, 2, 5 }; + objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + + // When + Scorer scorer1 = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1).scorer(leaf1); + DocIdSetIterator iterator1 = scorer1.iterator(); + Scorer scorer2 = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1).scorer(leaf1); + DocIdSetIterator iterator2 = scorer2.iterator(); + + int[] actualDocs = new int[2]; + float[] actualScores = new float[2]; + int index = 0; + while (iterator1.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + actualDocs[index] = iterator1.docID(); + actualScores[index] = scorer1.score(); + ++index; + } + + // Then + assertEquals(2, iterator1.cost()); + assertArrayEquals(new int[] { 0, 1 }, actualDocs); + assertArrayEquals(new float[] { 0.1f, 1.2f }, actualScores, 0.0001f); + + assertEquals(1.2f, scorer2.getMaxScore(1), 0.0001f); + assertEquals(iterator2.advance(1), 1); + } + + @SneakyThrows + public void testWeight() { + // Given + int[] expectedDocs = { 0, 1, 2, 3, 4 }; + float[] expectedScores = { 0.1f, 1.2f, 2.3f, 5.1f, 3.4f }; + int[] findSegments = { 0, 2, 5 }; + Explanation expectedExplanation = Explanation.match(1.2f, "within top 4"); + + // When + objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + Weight weight = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1); + Explanation explanation = weight.explain(leaf1, 1); + + // Then 
+ assertEquals(objectUnderTest, weight.getQuery()); + assertTrue(weight.isCacheable(leaf1)); + assertEquals(2, weight.count(leaf1)); + assertEquals(expectedExplanation, explanation); + assertEquals(Explanation.noMatch("not in top 4"), weight.explain(leaf1, 9)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryIT.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryIT.java new file mode 100644 index 0000000000..29f3689ab4 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryIT.java @@ -0,0 +1,190 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.nativelib; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import com.google.common.primitives.Floats; +import lombok.AllArgsConstructor; +import lombok.SneakyThrows; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.junit.BeforeClass; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.KNNResult; +import org.opensearch.knn.TestUtils; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.FaissHNSWFlatE2EIT; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.plugin.script.KNNScoringUtil; + +import java.io.IOException; +import java.net.URL; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.ThreadLocalRandom; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static 
com.carrotsearch.randomizedtesting.RandomizedTest.$$; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; + +@AllArgsConstructor +public class NativeEngineKNNVectorQueryIT extends KNNRestTestCase { + + private String description; + private int k; + private Map methodParameters; + private boolean deleteRandomDocs; + + static TestUtils.TestData testData; + + @BeforeClass + public static void setUpClass() throws IOException { + if (FaissHNSWFlatE2EIT.class.getClassLoader() == null) { + throw new IllegalStateException("ClassLoader of FaissIT Class is null"); + } + URL testIndexVectors = FaissHNSWFlatE2EIT.class.getClassLoader().getResource("data/test_vectors_1000x128.json"); + URL testQueries = FaissHNSWFlatE2EIT.class.getClassLoader().getResource("data/test_queries_100x128.csv"); + assert testIndexVectors != null; + assert testQueries != null; + testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath()); + } + + @ParametersFactory(argumentFormatting = "description:%1$s; k:%2$s; efSearch:%3$s, deleteDocs:%4$s") + public static Collection parameters() { + return Arrays.asList( + $$( + $("test without deletedocs", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), false), + $("test with deletedocs", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), true) + ) + ); + } + + @SneakyThrows + public void testResultComparisonSanity() { + String indexName = "test-index-1"; + String fieldName = "test-field-1"; + + SpaceType spaceType = SpaceType.L2; + + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + 
.startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, 16) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, 32) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, 32) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(indexName, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + fieldName, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents in the index + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + // Delete few Docs + if (deleteRandomDocs) { + final Set docIdsToBeDeleted = new HashSet<>(); + while (docIdsToBeDeleted.size() < 10) { + docIdsToBeDeleted.add(randomInt(testData.indexData.docs.length - 1)); + } + + for (Integer id : docIdsToBeDeleted) { + deleteKnnDoc(indexName, Integer.toString(testData.indexData.docs[id])); + } + refreshAllNonSystemIndices(); + forceMergeKnnIndex(indexName, 3); + + assertEquals(testData.indexData.docs.length - 10, getDocCount(indexName)); + } + + int queryIndex = ThreadLocalRandom.current().nextInt(testData.queries.length); + // Test search queries + final KNNQueryBuilder queryBuilder = KNNQueryBuilder.builder() + .fieldName(fieldName) + .vector(testData.queries[queryIndex]) + .k(k) + .methodParameters(methodParameters) + .build(); + Response 
nativeEngineResponse = searchKNNIndex(indexName, queryBuilder, k); + String responseBody = EntityUtils.toString(nativeEngineResponse.getEntity()); + List nativeEngineKnnResults = parseSearchResponse(responseBody, fieldName); + assertEquals(k, nativeEngineKnnResults.size()); + + List actualScores = parseSearchResponseScore(responseBody, fieldName); + for (int j = 0; j < k; j++) { + float[] primitiveArray = nativeEngineKnnResults.get(j).getVector(); + assertEquals( + KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[queryIndex], primitiveArray), spaceType), + actualScores.get(j), + 0.0001 + ); + } + + updateClusterSettings("knn.feature.query.rewrite.enabled", false); + Response launchControlDisabledResponse = searchKNNIndex(indexName, queryBuilder, k); + String launchControlDisabledResponseString = EntityUtils.toString(launchControlDisabledResponse.getEntity()); + List knnResults = parseSearchResponse(launchControlDisabledResponseString, fieldName); + assertEquals(k, knnResults.size()); + + assertEquals(nativeEngineKnnResults, knnResults); + + // Delete index + deleteKNNIndex(indexName); + + // Search every 5 seconds 14 times to confirm graph gets evicted + int intervals = 14; + for (int i = 0; i < intervals; i++) { + if (getTotalGraphsInCache() == 0) { + return; + } + Thread.sleep(5 * 1000); + } + + fail("Graphs are not getting evicted"); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java new file mode 100644 index 0000000000..1e4b11a12f --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -0,0 +1,156 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.nativelib; + +import lombok.SneakyThrows; +import org.apache.lucene.index.IndexReader; +import 
org.apache.lucene.index.IndexReaderContext; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TaskExecutor; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.util.Bits; +import org.mockito.ArgumentMatchers; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.opensearch.knn.index.query.KNNQuery; +import org.opensearch.knn.index.query.KNNWeight; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; + +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.openMocks; + +public class NativeEngineKNNVectorQueryTests extends OpenSearchTestCase { + + @Mock + private IndexSearcher searcher; + @Mock + private IndexReader reader; + @Mock + private KNNQuery knnQuery; + @Mock + private KNNWeight knnWeight; + @Mock + private TaskExecutor taskExecutor; + @Mock + private IndexReaderContext indexReaderContext; + @Mock + private LeafReaderContext leaf1; + @Mock + private LeafReaderContext leaf2; + @Mock + private LeafReader leafReader1; + @Mock + private LeafReader leafReader2; + + @InjectMocks + private NativeEngineKnnVectorQuery objectUnderTest; + + @Override + public void setUp() throws Exception { + super.setUp(); + openMocks(this); + + when(leaf1.reader()).thenReturn(leafReader1); + when(leaf2.reader()).thenReturn(leafReader2); + + when(searcher.getIndexReader()).thenReturn(reader); + when(knnQuery.createWeight(searcher, ScoreMode.COMPLETE, 1)).thenReturn(knnWeight); + + 
when(searcher.getTaskExecutor()).thenReturn(taskExecutor); + when(taskExecutor.invokeAll(ArgumentMatchers.>anyList())).thenAnswer(invocationOnMock -> { + List> callables = invocationOnMock.getArgument(0); + List topDocs = new ArrayList<>(); + for (Callable callable : callables) { + topDocs.add(callable.call()); + } + return topDocs; + }); + + when(reader.getContext()).thenReturn(indexReaderContext); + } + + @SneakyThrows + public void testMultiLeaf() { + // Given + List leaves = List.of(leaf1, leaf2); + when(reader.leaves()).thenReturn(leaves); + + when(knnWeight.searchLeaf(leaf1)).thenReturn(Map.of(0, 1.2f, 1, 5.1f, 2, 2.2f)); + when(knnWeight.searchLeaf(leaf2)).thenReturn(Map.of(4, 3.4f, 3, 5.1f)); + + // Making sure there is deleted docs in one of the segments + Bits liveDocs = mock(Bits.class); + when(leafReader1.getLiveDocs()).thenReturn(liveDocs); + when(leafReader2.getLiveDocs()).thenReturn(null); + + when(liveDocs.get(anyInt())).thenReturn(true); + when(liveDocs.get(2)).thenReturn(false); + when(liveDocs.get(1)).thenReturn(false); + + // k=4 to make sure we get topk results even if docs are deleted/less in one of the leaves + when(knnQuery.getK()).thenReturn(4); + + when(indexReaderContext.id()).thenReturn(1); + int[] expectedDocs = { 0, 3, 4 }; + float[] expectedScores = { 1.2f, 5.1f, 3.4f }; + int[] findSegments = { 0, 1, 3 }; + DocAndScoreQuery expected = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + + // When + Query actual = objectUnderTest.rewrite(searcher); + + // Then + assertEquals(expected, actual); + } + + @SneakyThrows + public void testSingleLeaf() { + // Given + List leaves = List.of(leaf1); + when(reader.leaves()).thenReturn(leaves); + when(knnWeight.searchLeaf(leaf1)).thenReturn(Map.of(0, 1.2f, 1, 5.1f, 2, 2.2f)); + when(knnQuery.getK()).thenReturn(4); + + when(indexReaderContext.id()).thenReturn(1); + int[] expectedDocs = { 0, 1, 2 }; + float[] expectedScores = { 1.2f, 5.1f, 2.2f }; + int[] findSegments = { 0, 3 
}; + DocAndScoreQuery expected = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + + // When + Query actual = objectUnderTest.rewrite(searcher); + + // Then + assertEquals(expected, actual); + } + + @SneakyThrows + public void testNoMatch() { + // Given + List leaves = List.of(leaf1); + when(reader.leaves()).thenReturn(leaves); + when(knnWeight.searchLeaf(leaf1)).thenReturn(Collections.emptyMap()); + when(knnQuery.getK()).thenReturn(4); + // When + Query actual = objectUnderTest.rewrite(searcher); + + // Then + assertEquals(new MatchNoDocsQuery(), actual); + } +} From e3158f990d058b02568da617688fd4857d0d521b Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 7 Aug 2024 17:18:03 -0700 Subject: [PATCH 2/6] Fix graph merge stats size calculation (#1844) * Fix graph merge stats size calculation Signed-off-by: Ryan Bogan * Add changelog entry Signed-off-by: Ryan Bogan * Add javadocs Signed-off-by: Ryan Bogan * Make calculations easier to read Signed-off-by: Ryan Bogan * Remove java overhead from calculations Signed-off-by: Ryan Bogan * Change from serialization mode to vector data type for calculations Signed-off-by: Ryan Bogan * Minor change to if statements Signed-off-by: Ryan Bogan --------- Signed-off-by: Ryan Bogan --- CHANGELOG.md | 1 + .../KNN80Codec/KNN80DocValuesConsumer.java | 7 +-- .../knn/index/codec/util/KNNCodecUtil.java | 54 ++++++------------- .../index/codec/util/KNNCodecUtilTests.java | 19 +++++++ 4 files changed, 40 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e851b8f369..44f387533a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Bug Fixes * Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874) * Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917) +* Fix graph merge 
stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 8e191ac5f3..989af4063b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -131,6 +131,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, NativeIndexCreator indexCreator; KNNCodecUtil.Pair pair; Map fieldAttributes = field.attributes(); + VectorDataType vectorDataType; if (fieldAttributes.containsKey(MODEL_ID)) { String modelId = fieldAttributes.get(MODEL_ID); @@ -138,12 +139,12 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, if (model.getModelBlob() == null) { throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); } - VectorDataType vectorDataType = model.getModelMetadata().getVectorDataType(); + vectorDataType = model.getModelMetadata().getVectorDataType(); pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType)); indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath); } else { // get vector data type from field attributes or provide default value - VectorDataType vectorDataType = VectorDataType.get( + vectorDataType = VectorDataType.get( fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()) ); pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType)); @@ -156,7 +157,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, return; } - long arraySize = calculateArraySize(pair.docs.length, pair.getDimension(), pair.serializationMode); + long arraySize = 
calculateArraySize(pair.docs.length, pair.getDimension(), vectorDataType); if (isMerge) { KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment(); diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index d208d8179b..ea14fe8834 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -11,6 +11,7 @@ import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues; import org.opensearch.knn.index.codec.transfer.VectorTransfer; @@ -21,12 +22,6 @@ public class KNNCodecUtil { // Floats are 4 bytes in size public static final int FLOAT_BYTE_SIZE = 4; - // References to objects are 4 bytes in size - public static final int JAVA_REFERENCE_SIZE = 4; - // Each array in Java has a header that is 12 bytes - public static final int JAVA_ARRAY_HEADER_SIZE = 12; - // Java rounds each array size up to multiples of 8 bytes - public static final int JAVA_ROUNDING_NUMBER = 8; @AllArgsConstructor public static final class Pair { @@ -67,39 +62,22 @@ public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final Vect ); } - public static long calculateArraySize(int numVectors, int vectorLength, SerializationMode serializationMode) { - if (serializationMode == SerializationMode.ARRAY) { - int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE; - if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { - vectorSize += vectorSize % JAVA_ROUNDING_NUMBER; - } - int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE) + JAVA_ARRAY_HEADER_SIZE; - if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) { - vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER; - } - return vectorsSize; - } else if 
(serializationMode == SerializationMode.COLLECTION_OF_FLOATS) { - int vectorSize = vectorLength * FLOAT_BYTE_SIZE; - if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { - vectorSize += vectorSize % JAVA_ROUNDING_NUMBER; - } - int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE); - if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) { - vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER; - } - return vectorsSize; - } else if (serializationMode == SerializationMode.COLLECTIONS_OF_BYTES) { - int vectorSize = vectorLength; - if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { - vectorSize += vectorSize % JAVA_ROUNDING_NUMBER; - } - int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE); - if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) { - vectorsSize += vectorsSize % JAVA_ROUNDING_NUMBER; - } - return vectorsSize; + /** + * This method provides a rough estimate of the number of bytes used for storing an array with the given parameters. + * @param numVectors number of vectors in the array + * @param vectorLength the length of each vector + * @param vectorDataType type of data stored in each vector + * @return rough estimate of number of bytes used to store an array with the given parameters; computed in long arithmetic + */ + public static long calculateArraySize(int numVectors, int vectorLength, VectorDataType vectorDataType) { + if (vectorDataType == VectorDataType.FLOAT) { + return (long) numVectors * vectorLength * FLOAT_BYTE_SIZE; + } else if (vectorDataType == VectorDataType.BINARY || vectorDataType == VectorDataType.BYTE) { + return (long) numVectors * vectorLength; } else { - throw new IllegalStateException("Unreachable code"); + throw new IllegalArgumentException( + "Float, binary, and byte are the only supported vector data types for array size calculation." 
+ ); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java index 2ff0f08e51..47dd1dda99 100644 --- a/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/util/KNNCodecUtilTests.java @@ -9,6 +9,7 @@ import lombok.SneakyThrows; import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.util.BytesRef; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.transfer.VectorTransfer; import java.util.Arrays; @@ -18,6 +19,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.knn.index.codec.util.KNNCodecUtil.calculateArraySize; public class KNNCodecUtilTests extends TestCase { @SneakyThrows @@ -52,4 +54,21 @@ public void testGetPair_whenCalled_thenReturn() { assertEquals(dimension, pair.getDimension()); assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, pair.serializationMode); } + + public void testCalculateArraySize() { + int numVectors = 4; + int vectorLength = 10; + + // Float data type + VectorDataType vectorDataType = VectorDataType.FLOAT; + assertEquals(160, calculateArraySize(numVectors, vectorLength, vectorDataType)); + + // Byte data type + vectorDataType = VectorDataType.BYTE; + assertEquals(40, calculateArraySize(numVectors, vectorLength, vectorDataType)); + + // Binary data type + vectorDataType = VectorDataType.BINARY; + assertEquals(40, calculateArraySize(numVectors, vectorLength, vectorDataType)); + } } From 56698f774fde9e08bee7ad8372a17b7948362b87 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Thu, 8 Aug 2024 13:46:42 -0400 Subject: [PATCH 3/6] Fix CI for 2.17 (#1940) Signed-off-by: John Mazanec --- .github/workflows/backwards_compatibility_tests_workflow.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/.github/workflows/backwards_compatibility_tests_workflow.yml b/.github/workflows/backwards_compatibility_tests_workflow.yml index 90a20eda89..c20201dab0 100644 --- a/.github/workflows/backwards_compatibility_tests_workflow.yml +++ b/.github/workflows/backwards_compatibility_tests_workflow.yml @@ -35,7 +35,7 @@ jobs: matrix: java: [ 21 ] os: [ubuntu-latest] - bwc_version : [ "2.0.1", "2.1.0", "2.2.1", "2.3.0", "2.4.1", "2.5.0", "2.6.0", "2.7.0", "2.8.0", "2.9.0", "2.10.0", "2.11.0", "2.12.0", "2.13.0", "2.14.0", "2.15.0", "2.16.0-SNAPSHOT"] + bwc_version : [ "2.0.1", "2.1.0", "2.2.1", "2.3.0", "2.4.1", "2.5.0", "2.6.0", "2.7.0", "2.8.0", "2.9.0", "2.10.0", "2.11.0", "2.12.0", "2.13.0", "2.14.0", "2.15.0", "2.16.0", "2.17.0-SNAPSHOT"] opensearch_version : [ "3.0.0-SNAPSHOT" ] exclude: - os: windows-latest @@ -114,7 +114,7 @@ jobs: matrix: java: [ 21 ] os: [ubuntu-latest] - bwc_version: [ "2.16.0-SNAPSHOT" ] + bwc_version: [ "2.17.0-SNAPSHOT" ] opensearch_version: [ "3.0.0-SNAPSHOT" ] name: k-NN Rolling-Upgrade BWC Tests From 2cd57e88665ef9ea3406fb90b49a9fc843901b10 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Sat, 10 Aug 2024 11:01:00 -0400 Subject: [PATCH 4/6] Refactor Around Mapper and Mapping (#1939) Refactors FieldMapper logic. It removes the LegacyFieldMapper and replaces it with a FlatFieldMapper. The FlatFieldMapper's role is to create fields that do not build ANN indices. Additionally, it puts dimension, model_id, and knn_method_context in a new KNNMappingConfig class and adds some safety checks around accessing them. This should make calling logic easier to handle. Lastly, it cleans up the parsing so that there isnt encoder parsing directly in the KNNVectorFieldMapper. 
Signed-off-by: John Mazanec --- CHANGELOG.md | 3 +- .../org/opensearch/knn/bwc/IndexingIT.java | 15 + .../codec/BasePerFieldKnnVectorsFormat.java | 15 +- .../KNN80Codec/KNN80DocValuesConsumer.java | 5 +- .../knn/index/engine/KNNMethodContext.java | 18 - .../index/mapper/FlatVectorFieldMapper.java | 91 ++++ .../knn/index/mapper/KNNMappingConfig.java | 38 ++ .../index/mapper/KNNVectorFieldMapper.java | 479 ++++++++---------- .../mapper/KNNVectorFieldMapperUtil.java | 236 ++++++--- .../knn/index/mapper/KNNVectorFieldType.java | 56 +- .../knn/index/mapper/LegacyFieldMapper.java | 130 ----- .../knn/index/mapper/LuceneFieldMapper.java | 96 +++- .../knn/index/mapper/MethodFieldMapper.java | 121 ++++- .../knn/index/mapper/ModelFieldMapper.java | 183 ++++++- .../index/mapper/PerDimensionProcessor.java | 51 ++ .../index/mapper/PerDimensionValidator.java | 80 +++ .../index/mapper/SpaceVectorValidator.java | 28 + .../knn/index/mapper/VectorValidator.java | 28 + .../knn/index/query/KNNQueryBuilder.java | 83 +-- .../org/opensearch/knn/plugin/KNNPlugin.java | 2 - .../java/org/opensearch/knn/KNNTestCase.java | 55 ++ .../knn/index/KNNMethodContextTests.java | 2 +- .../KNN80DocValuesConsumerTests.java | 56 -- .../knn/index/codec/KNNCodecTestCase.java | 46 +- .../mapper/KNNVectorFieldMapperTests.java | 211 ++++---- .../mapper/KNNVectorFieldMapperUtilTests.java | 59 +-- .../index/mapper/MethodFieldMapperTests.java | 46 +- .../knn/index/query/KNNQueryBuilderTests.java | 119 ++--- .../knn/integ/KNNScriptScoringIT.java | 19 +- .../script/KNNScoringSpaceFactoryTests.java | 12 +- .../plugin/script/KNNScoringSpaceTests.java | 45 +- .../script/KNNScoringSpaceUtilTests.java | 2 +- ...TrainingJobRouterTransportActionTests.java | 7 +- .../transport/TrainingModelRequestTests.java | 4 +- 34 files changed, 1465 insertions(+), 976 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java create mode 100644 
src/main/java/org/opensearch/knn/index/mapper/KNNMappingConfig.java delete mode 100644 src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java create mode 100644 src/main/java/org/opensearch/knn/index/mapper/PerDimensionProcessor.java create mode 100644 src/main/java/org/opensearch/knn/index/mapper/PerDimensionValidator.java create mode 100644 src/main/java/org/opensearch/knn/index/mapper/SpaceVectorValidator.java create mode 100644 src/main/java/org/opensearch/knn/index/mapper/VectorValidator.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 44f387533a..81c90802ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,4 +31,5 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920) * Refactor KNNVectorFieldType from KNNVectorFieldMapper to a separate class for better readability. [#1931](https://github.com/opensearch-project/k-NN/pull/1931) * Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925) -* Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877) \ No newline at end of file +* Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877) +* Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939) \ No newline at end of file diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/IndexingIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/IndexingIT.java index 2df79a3a24..1531dd0da1 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/IndexingIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/IndexingIT.java @@ -53,6 +53,21 @@ public void 
testKNNIndexDefaultLegacyFieldMapping() throws Exception { } } + // Ensure that when segments created with old mapping are forcemerged in new cluster, they + // succeed + public void testKNNIndexDefaultLegacyFieldMappingForceMerge() throws Exception { + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + + if (isRunningAgainstOldCluster()) { + createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS)); + addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, DOC_ID, 100); + // Flush to ensure that index is not re-indexed when node comes back up + flush(testIndex, true); + } else { + forceMergeKnnIndex(testIndex); + } + } + // Custom Legacy Field Mapping // space_type : "linf", engine : "nmslib", m : 2, ef_construction : 2 public void testKNNIndexCustomLegacyFieldMapping() throws Exception { diff --git a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java index 2a3732d7e0..69229036ed 100644 --- a/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java +++ b/src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java @@ -13,6 +13,8 @@ import org.opensearch.knn.index.codec.params.KNNScalarQuantizedVectorsFormatParams; import org.opensearch.knn.index.codec.params.KNNVectorsFormatParams; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; import java.util.Optional; @@ -66,16 +68,19 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) { ); return defaultFormatSupplier.get(); } - var type = (KNNVectorFieldType) mapperService.orElseThrow( + KNNVectorFieldType mappedFieldType = (KNNVectorFieldType) mapperService.orElseThrow( () -> new IllegalStateException( String.format("Cannot read field type for field 
[%s] because mapper service is not available", field) ) ).fieldType(field); - var params = type.getKnnMethodContext().getMethodComponentContext().getParameters(); - if (type.getKnnMethodContext().getKnnEngine() == KNNEngine.LUCENE - && params != null - && params.containsKey(METHOD_ENCODER_PARAMETER)) { + KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); + KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext() + .orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); + + var params = knnMethodContext.getMethodComponentContext().getParameters(); + + if (knnMethodContext.getKnnEngine() == KNNEngine.LUCENE && params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) { KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( params, defaultMaxConnections, diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 989af4063b..55ac5c5978 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -14,6 +14,7 @@ import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; @@ -21,7 +22,6 @@ import org.opensearch.knn.index.codec.transfer.VectorTransferByte; import org.opensearch.knn.index.codec.transfer.VectorTransferFloat; import org.opensearch.knn.jni.JNIService; -import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.KNNCodecUtil; import 
org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.indices.Model; @@ -216,8 +216,7 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa throws IOException { Map parameters = new HashMap<>(); Map fieldAttributes = fieldInfo.attributes(); - String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); - + String parametersString = fieldAttributes.get(PARAMETERS); // parametersString will be null when legacy mapper is used if (parametersString == null) { parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java b/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java index 7885761b93..d210483e60 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNMethodContext.java @@ -9,7 +9,6 @@ import lombok.Getter; import lombok.NonNull; import lombok.Setter; -import org.opensearch.Version; import org.opensearch.common.ValidationException; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -20,7 +19,6 @@ import org.opensearch.index.mapper.MapperParsingException; import java.io.IOException; -import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.stream.Collectors; @@ -29,7 +27,6 @@ import org.opensearch.knn.training.VectorSpaceInfo; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; @@ -42,21 +39,6 @@ @Getter public class KNNMethodContext implements ToXContentFragment, Writeable { - private 
static KNNMethodContext defaultInstance = null; - - /** - * This is used only for testing - * @return default KNNMethodContext for testing - */ - public static synchronized KNNMethodContext getDefault() { - if (defaultInstance == null) { - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); - methodComponentContext.setIndexVersion(Version.CURRENT); - defaultInstance = new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, methodComponentContext); - } - return defaultInstance; - } - @NonNull private final KNNEngine knnEngine; @NonNull diff --git a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java new file mode 100644 index 0000000000..fffff30f41 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.apache.lucene.document.FieldType; +import org.opensearch.Version; +import org.opensearch.common.Explicit; +import org.opensearch.knn.index.VectorDataType; + +import java.util.Map; + +/** + * Mapper used when you don't want to build an underlying KNN struct - you just want to + * store vectors as doc values + */ +public class FlatVectorFieldMapper extends KNNVectorFieldMapper { + + private final PerDimensionValidator perDimensionValidator; + + public static FlatVectorFieldMapper createFieldMapper( + String fullname, + String simpleName, + Map metaValue, + VectorDataType vectorDataType, + Integer dimension, + MultiFields multiFields, + CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues, + Version indexCreatedVersion + ) { + final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(fullname, metaValue, vectorDataType, () -> dimension); + return new FlatVectorFieldMapper(
simpleName, + mappedFieldType, + multiFields, + copyTo, + ignoreMalformed, + stored, + hasDocValues, + indexCreatedVersion + ); + } + + private FlatVectorFieldMapper( + String simpleName, + KNNVectorFieldType mappedFieldType, + MultiFields multiFields, + CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues, + Version indexCreatedVersion + ) { + super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion, null); + this.perDimensionValidator = selectPerDimensionValidator(vectorDataType); + this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); + this.fieldType.freeze(); + } + + private PerDimensionValidator selectPerDimensionValidator(VectorDataType vectorDataType) { + if (VectorDataType.BINARY == vectorDataType) { + return PerDimensionValidator.DEFAULT_BIT_VALIDATOR; + } + + if (VectorDataType.BYTE == vectorDataType) { + return PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; + } + + return PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; + } + + @Override + protected VectorValidator getVectorValidator() { + return VectorValidator.NOOP_VECTOR_VALIDATOR; + } + + @Override + protected PerDimensionValidator getPerDimensionValidator() { + return perDimensionValidator; + } + + @Override + protected PerDimensionProcessor getPerDimensionProcessor() { + return PerDimensionProcessor.NOOP_PROCESSOR; + } +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNMappingConfig.java b/src/main/java/org/opensearch/knn/index/mapper/KNNMappingConfig.java new file mode 100644 index 0000000000..4fcd6e1bca --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNMappingConfig.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.opensearch.knn.index.engine.KNNMethodContext; + +import java.util.Optional; + +/** + * Class holds information about how the ANN 
indices are created. The design of this class ensures that we do not + * accidentally configure an index that has multiple ways it can be created. This class is immutable. + */ +public interface KNNMappingConfig { + /** + * + * @return Optional containing the modelId if created from model, otherwise empty + */ + default Optional getModelId() { + return Optional.empty(); + } + + /** + * + * @return Optional containing the KNNMethodContext if created from method, otherwise empty + */ + default Optional getKnnMethodContext() { + return Optional.empty(); + } + + /** + * + * @return the dimension of the index. NOTE(review): the return type is a primitive int, so it cannot be null for model based indices - confirm what value is returned in that case + */ + int getDimension(); +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 3b94876454..40eaa12aeb 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -11,7 +11,6 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.function.Supplier; import lombok.extern.log4j.Log4j2; @@ -32,9 +31,8 @@ import org.opensearch.index.mapper.ParametrizedFieldMapper; import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.KnnCircuitBreakerException; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.VectorDataType; @@ -44,23 +42,17 @@ import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNConstants.ENCODER_FLAT; -import static
org.opensearch.knn.common.KNNConstants.ENCODER_SQ; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; -import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; -import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; -import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; import static org.opensearch.knn.common.KNNValidationUtil.validateVectorDimension; -import static org.opensearch.knn.index.KNNSettings.KNN_INDEX; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createKNNMethodContextFromLegacy; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateIfCircuitBreakerIsNotTriggered; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateIfKNNPluginEnabled; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataType; -import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateVectorDataTypeWithKnnIndexSetting; +import static org.opensearch.knn.index.mapper.ModelFieldMapper.UNSET_MODEL_DIMENSION_IDENTIFIER; /** * Field Mapper for KNN vector type. Implementations of this class define what needs to be stored in Lucene's fieldType. 
@@ -76,10 +68,6 @@ private static KNNVectorFieldMapper toType(FieldMapper in) { return (KNNVectorFieldMapper) in; } - // We store the version of the index with the mapper as different version of Opensearch has different default - // values of KNN engine Algorithms hyperparameters. - protected Version indexCreatedVersion; - /** * Builder for KNNVectorFieldMapper. This class defines the set of parameters that can be applied to the knn_vector * field type @@ -89,25 +77,37 @@ public static class Builder extends ParametrizedFieldMapper.Builder { protected final Parameter stored = Parameter.storeParam(m -> toType(m).stored, false); protected final Parameter hasDocValues = Parameter.docValuesParam(m -> toType(m).hasDocValues, true); - protected final Parameter dimension = new Parameter<>(KNNConstants.DIMENSION, false, () -> -1, (n, c, o) -> { - if (o == null) { - throw new IllegalArgumentException("Dimension cannot be null"); - } - int value; - try { - value = XContentMapValues.nodeIntegerValue(o); - } catch (Exception exception) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "Unable to parse [dimension] from provided value [%s] for vector [%s]", o, name) - ); - } - if (value <= 0) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "Dimension value must be greater than 0 for vector: %s", name) - ); + protected final Parameter dimension = new Parameter<>( + KNNConstants.DIMENSION, + false, + () -> UNSET_MODEL_DIMENSION_IDENTIFIER, + (n, c, o) -> { + if (o == null) { + throw new IllegalArgumentException("Dimension cannot be null"); + } + int value; + try { + value = XContentMapValues.nodeIntegerValue(o); + } catch (Exception exception) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Unable to parse [dimension] from provided value [%s] for vector [%s]", o, name) + ); + } + if (value <= 0) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Dimension value must be greater than 0 for vector: %s", 
name) + ); + } + return value; + }, + m -> { + KNNMappingConfig knnMappingConfig = toType(m).fieldType().getKnnMappingConfig(); + if (knnMappingConfig.getModelId().isPresent()) { + return UNSET_MODEL_DIMENSION_IDENTIFIER; + } + return knnMappingConfig.getDimension(); } - return value; - }, m -> toType(m).dimension); + ); /** * data_type which defines the datatype of the vector values. This is an optional parameter and @@ -126,7 +126,12 @@ public static class Builder extends ParametrizedFieldMapper.Builder { * model template index. If this parameter is set, it will take precedence. This parameter is only relevant for * library indices that require training. */ - protected final Parameter modelId = Parameter.stringParam(KNNConstants.MODEL_ID, false, m -> toType(m).modelId, null); + protected final Parameter modelId = Parameter.stringParam( + KNNConstants.MODEL_ID, + false, + m -> toType(m).fieldType().getKnnMappingConfig().getModelId().orElse(null), + null + ); /** * knnMethodContext parameter allows a user to define their k-NN library index configuration. Defaults to an L2 @@ -137,7 +142,7 @@ public static class Builder extends ParametrizedFieldMapper.Builder { false, () -> null, (n, c, o) -> KNNMethodContext.parse(o), - m -> toType(m).knnMethod + m -> toType(m).originalKNNMethodContext ).setSerializer(((b, n, v) -> { b.startObject(n); v.toXContent(b, ToXContent.EMPTY_PARAMS); @@ -164,35 +169,30 @@ public static class Builder extends ParametrizedFieldMapper.Builder { protected final Parameter> meta = Parameter.metaParam(); - protected String spaceType; - protected String m; - protected String efConstruction; - protected ModelDao modelDao; - protected Version indexCreatedVersion; - - public Builder(String name, ModelDao modelDao, Version indexCreatedVersion) { + // KNNMethodContext that allows us to properly configure a KNNVectorFieldMapper from another + // KNNVectorFieldMapper. 
To support our legacy field mapping, on parsing, if index.knn=true and no method is + // passed, we build a KNNMethodContext using the space type, ef_construction and m that are set in the index + // settings. However, for fieldmappers for merging, we need to be able to initialize one field mapper from + // another (see + // https://github.com/opensearch-project/OpenSearch/blob/2.16.0/server/src/main/java/org/opensearch/index/mapper/ParametrizedFieldMapper.java#L98). + // The problem is that in this case, the settings are set to empty so we cannot properly resolve the KNNMethodContext. + // (see + // https://github.com/opensearch-project/OpenSearch/blob/2.16.0/server/src/main/java/org/opensearch/index/mapper/ParametrizedFieldMapper.java#L130). + // While we could override the KNNMethodContext parameter initializer to set the knnMethodContext based on the + // constructed KNNMethodContext from the other field mapper, this can result in merge conflict/serialization + // exceptions. See + // (https://github.com/opensearch-project/OpenSearch/blob/2.16.0/server/src/main/java/org/opensearch/index/mapper/ParametrizedFieldMapper.java#L322-L324). + // So, what we do is pass in a "resolvedKNNMethodContext" that will either be null or be set via the merge builder + // constructor. A similar approach was taken for https://github.com/opendistro-for-elasticsearch/k-NN/issues/288 + private KNNMethodContext resolvedKNNMethodContext; + + public Builder(String name, ModelDao modelDao, Version indexCreatedVersion, KNNMethodContext resolvedKNNMethodContext) { super(name); this.modelDao = modelDao; this.indexCreatedVersion = indexCreatedVersion; - } - - /** - * This constructor is for legacy purposes. 
- * Checkout ODFE PR 288 - * - * @param name field name - * @param spaceType Spacetype of field - * @param m m value of field - * @param efConstruction efConstruction value of field - */ - public Builder(String name, String spaceType, String m, String efConstruction, Version indexCreatedVersion) { - super(name); - this.spaceType = spaceType; - this.m = m; - this.efConstruction = efConstruction; - this.indexCreatedVersion = indexCreatedVersion; + this.resolvedKNNMethodContext = resolvedKNNMethodContext; } @Override @@ -210,121 +210,117 @@ protected Explicit ignoreMalformed(BuilderContext context) { return KNNVectorFieldMapper.Defaults.IGNORE_MALFORMED; } + private void validateFlatMapper() { + if (modelId.get() != null || knnMethodContext.get() != null) { + throw new IllegalArgumentException("Cannot set modelId or method parameters when index.knn setting is false"); + } + } + @Override public KNNVectorFieldMapper build(BuilderContext context) { - // Originally, a user would use index settings to set the spaceType, efConstruction and m hnsw - // parameters. Upon further review, it makes sense to set these parameters in the mapping of a - // particular field. However, because users migrating from older versions will still use the index - // settings to set these parameters, we will need to provide backwards compatibilty. In order to - // handle this, we first check if the mapping is set, and, if so use it. If not, we check if the model is - // set. If not, we fall back to the parameters set in the index settings. This means that if a user sets - // the mappings, setting the index settings will have no impact. 
- - final KNNMethodContext knnMethodContext = this.knnMethodContext.getValue(); - setDefaultSpaceType(knnMethodContext, vectorDataType.getValue()); - validateSpaceType(knnMethodContext, vectorDataType.getValue()); - validateDimensions(knnMethodContext, vectorDataType.getValue()); - validateEncoder(knnMethodContext, vectorDataType.getValue()); final MultiFields multiFieldsBuilder = this.multiFieldsBuilder.build(this, context); final CopyTo copyToBuilder = copyTo.build(); final Explicit ignoreMalformed = ignoreMalformed(context); final Map metaValue = meta.getValue(); - if (knnMethodContext != null) { - validateVectorDataType(knnMethodContext, vectorDataType.getValue()); - knnMethodContext.getMethodComponentContext().setIndexVersion(indexCreatedVersion); - final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( + // Index is being created from model + String modelIdAsString = this.modelId.get(); + if (modelIdAsString != null) { + return ModelFieldMapper.createFieldMapper( buildFullName(context), - metaValue, - dimension.getValue(), - knnMethodContext, - vectorDataType.getValue() - ); - if (knnMethodContext.getKnnEngine() == KNNEngine.LUCENE) { - log.debug(String.format(Locale.ROOT, "Use [LuceneFieldMapper] mapper for field [%s]", name)); - LuceneFieldMapper.CreateLuceneFieldMapperInput createLuceneFieldMapperInput = - LuceneFieldMapper.CreateLuceneFieldMapperInput.builder() - .name(name) - .mappedFieldType(mappedFieldType) - .multiFields(multiFieldsBuilder) - .copyTo(copyToBuilder) - .ignoreMalformed(ignoreMalformed) - .stored(stored.get()) - .hasDocValues(hasDocValues.get()) - .vectorDataType(vectorDataType.getValue()) - .knnMethodContext(knnMethodContext) - .build(); - return new LuceneFieldMapper(createLuceneFieldMapperInput); - } - - return new MethodFieldMapper( name, - mappedFieldType, + metaValue, + vectorDataType.getValue(), + modelIdAsString, multiFieldsBuilder, copyToBuilder, ignoreMalformed, stored.get(), hasDocValues.get(), - knnMethodContext + 
modelDao, + indexCreatedVersion ); } - String modelIdAsString = this.modelId.get(); - if (modelIdAsString != null) { - // Because model information is stored in cluster metadata, we are unable to get it here. This is - // because to get the cluster metadata, you need access to the cluster state. Because this code is - // sometimes used to initialize the cluster state/update cluster state, we cannot get the state here - // safely. So, we are unable to validate the model. The model gets validated during ingestion. - - return new ModelFieldMapper( + // If the field mapper is using the legacy context and being constructed from another field mapper, + // the settings will be empty. See https://github.com/opendistro-for-elasticsearch/k-NN/issues/288. In this + // case, the input resolvedKNNMethodContext will be null and the settings won't exist (so flat mapper should + // be used). Otherwise, we need to check the setting. + boolean isResolvedNull = resolvedKNNMethodContext == null; + boolean isSettingPresent = KNNSettings.IS_KNN_INDEX_SETTING.exists(context.indexSettings()); + boolean isKnnSettingNotPresentOrFalse = !isSettingPresent || !KNNSettings.IS_KNN_INDEX_SETTING.get(context.indexSettings()); + if (isResolvedNull && isKnnSettingNotPresentOrFalse) { + validateFlatMapper(); + return FlatVectorFieldMapper.createFieldMapper( + buildFullName(context), name, - new KNNVectorFieldType(buildFullName(context), metaValue, -1, knnMethodContext, modelIdAsString), + metaValue, + vectorDataType.getValue(), + dimension.getValue(), multiFieldsBuilder, copyToBuilder, ignoreMalformed, stored.get(), hasDocValues.get(), - modelDao, - modelIdAsString, indexCreatedVersion ); } - // Build legacy - if (this.spaceType == null) { - this.spaceType = LegacyFieldMapper.getSpaceType(context.indexSettings(), vectorDataType.getValue()); - } - - if (this.m == null) { - this.m = LegacyFieldMapper.getM(context.indexSettings()); - } - - if (this.efConstruction == null) { - this.efConstruction =
LegacyFieldMapper.getEfConstruction(context.indexSettings(), indexCreatedVersion); - } - - // Validates and throws exception if index.knn is set to true in the index settings - // using any VectorDataType (other than float, which is default) because we are using NMSLIB engine for LegacyFieldMapper - // and it only supports float VectorDataType - validateVectorDataTypeWithKnnIndexSetting(context.indexSettings().getAsBoolean(KNN_INDEX, false), vectorDataType); - - return new LegacyFieldMapper( - name, - new KNNVectorFieldType( + // See resolvedKNNMethodContext definition for explanation + if (isResolvedNull) { + resolvedKNNMethodContext = this.knnMethodContext.getValue(); + setDefaultSpaceType(resolvedKNNMethodContext, vectorDataType.getValue()); + validateSpaceType(resolvedKNNMethodContext, vectorDataType.getValue()); + validateDimensions(resolvedKNNMethodContext, vectorDataType.getValue()); + validateEncoder(resolvedKNNMethodContext, vectorDataType.getValue()); + } + + // If the knnMethodContext is null at this point, that means user built the index with the legacy k-NN + // settings to specify algo params. 
We need to convert this here to a KNNMethodContext so that we can + // properly configure the rest of the index + if (resolvedKNNMethodContext == null) { + resolvedKNNMethodContext = createKNNMethodContextFromLegacy(context, vectorDataType.getValue(), indexCreatedVersion); + } + + validateVectorDataType(resolvedKNNMethodContext, vectorDataType.getValue()); + resolvedKNNMethodContext.getMethodComponentContext().setIndexVersion(indexCreatedVersion); + if (resolvedKNNMethodContext.getKnnEngine() == KNNEngine.LUCENE) { + log.debug(String.format(Locale.ROOT, "Use [LuceneFieldMapper] mapper for field [%s]", name)); + LuceneFieldMapper.CreateLuceneFieldMapperInput createLuceneFieldMapperInput = LuceneFieldMapper.CreateLuceneFieldMapperInput + .builder() + .name(name) + .multiFields(multiFieldsBuilder) + .copyTo(copyToBuilder) + .ignoreMalformed(ignoreMalformed) + .stored(stored.getValue()) + .hasDocValues(hasDocValues.getValue()) + .vectorDataType(vectorDataType.getValue()) + .indexVersion(indexCreatedVersion) + .originalKnnMethodContext(knnMethodContext.get()) + .build(); + return LuceneFieldMapper.createFieldMapper( buildFullName(context), metaValue, - dimension.getValue(), vectorDataType.getValue(), - SpaceType.getSpace(spaceType) - ), + dimension.getValue(), + resolvedKNNMethodContext, + createLuceneFieldMapperInput + ); + } + + return MethodFieldMapper.createFieldMapper( + buildFullName(context), + name, + metaValue, + vectorDataType.getValue(), + dimension.getValue(), + resolvedKNNMethodContext, + knnMethodContext.get(), multiFieldsBuilder, copyToBuilder, ignoreMalformed, - stored.get(), - hasDocValues.get(), - spaceType, - m, - efConstruction, + stored.getValue(), + hasDocValues.getValue(), indexCreatedVersion ); } @@ -430,7 +426,7 @@ public TypeParser(Supplier modelDaoSupplier) { @Override public Mapper.Builder parse(String name, Map node, ParserContext parserContext) throws MapperParsingException { - Builder builder = new KNNVectorFieldMapper.Builder(name, 
modelDaoSupplier.get(), parserContext.indexVersionCreated()); + Builder builder = new KNNVectorFieldMapper.Builder(name, modelDaoSupplier.get(), parserContext.indexVersionCreated(), null); builder.parse(name, parserContext, node); // All parse(String name, Map node, ParserCont } // Dimension should not be null unless modelId is used - if (builder.dimension.getValue() == -1 && builder.modelId.get() == null) { + if (builder.dimension.getValue() == UNSET_MODEL_DIMENSION_IDENTIFIER && builder.modelId.get() == null) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Dimension value missing for vector: %s", name)); } @@ -452,18 +448,19 @@ public Mapper.Builder parse(String name, Map node, ParserCont } } + // We store the version of the index with the mapper as different version of Opensearch has different default + // values of KNN engine Algorithms hyperparameters. + protected Version indexCreatedVersion; protected Explicit ignoreMalformed; protected boolean stored; protected boolean hasDocValues; - protected Integer dimension; protected VectorDataType vectorDataType; protected ModelDao modelDao; - // These members map to parameters in the builder. They need to be declared in the abstract class due to the - // "toType" function used in the builder. So, when adding a parameter, it needs to be added here, but set in a - // subclass (if it is unique). - protected KNNMethodContext knnMethod; - protected String modelId; + // We need to ensure that the original KNNMethodContext as parsed is stored to initialize the + // Builder for serialization. So, we need to store it here. 
This is mainly to ensure that the legacy field mapper + // can use KNNMethodContext without messing up serialization on mapper merge + protected KNNMethodContext originalKNNMethodContext; public KNNVectorFieldMapper( String simpleName, @@ -473,16 +470,17 @@ public KNNVectorFieldMapper( Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - Version indexCreatedVersion + Version indexCreatedVersion, + KNNMethodContext originalKNNMethodContext ) { super(simpleName, mappedFieldType, multiFields, copyTo); this.ignoreMalformed = ignoreMalformed; this.stored = stored; this.hasDocValues = hasDocValues; - this.dimension = mappedFieldType.getDimension(); this.vectorDataType = mappedFieldType.getVectorDataType(); updateEngineStats(); this.indexCreatedVersion = indexCreatedVersion; + this.originalKNNMethodContext = originalKNNMethodContext; } public KNNVectorFieldMapper clone() { @@ -496,20 +494,7 @@ protected String contentType() { @Override protected void parseCreateField(ParseContext context) throws IOException { - parseCreateField( - context, - fieldType().getDimension(), - fieldType().getSpaceType(), - getMethodComponentContext(fieldType().getKnnMethodContext()), - fieldType().getVectorDataType() - ); - } - - private MethodComponentContext getMethodComponentContext(KNNMethodContext knnMethodContext) { - if (Objects.isNull(knnMethodContext)) { - return null; - } - return knnMethodContext.getMethodComponentContext(); + parseCreateField(context, fieldType().getKnnMappingConfig().getDimension(), fieldType().getVectorDataType()); } /** @@ -544,17 +529,37 @@ protected List getFieldsForByteVector(final byte[] array, final FieldType return fields; } - protected void parseCreateField( - ParseContext context, - int dimension, - SpaceType spaceType, - MethodComponentContext methodComponentContext, - VectorDataType vectorDataType - ) throws IOException { - + /** + * Validation checks before parsing of doc begins + */ + protected void validatePreparse() { 
validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); - spaceType.validateVectorDataType(vectorDataType); + } + + /** + * Getter for vector validator after vector parsing + * + * @return VectorValidator + */ + protected abstract VectorValidator getVectorValidator(); + + /** + * Getter for per dimension validator during vector parsing + * + * @return PerDimensionValidator + */ + protected abstract PerDimensionValidator getPerDimensionValidator(); + + /** + * Getter for per dimension processor during vector parsing + * + * @return PerDimensionProcessor + */ + protected abstract PerDimensionProcessor getPerDimensionProcessor(); + + protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException { + validatePreparse(); if (VectorDataType.BINARY == vectorDataType) { Optional bytesArrayOptional = getBytesFromContext(context, dimension, vectorDataType); @@ -563,7 +568,7 @@ protected void parseCreateField( return; } final byte[] array = bytesArrayOptional.get(); - spaceType.validateVector(array); + getVectorValidator().validateVector(array); context.doc().addAll(getFieldsForByteVector(array, fieldType)); } else if (VectorDataType.BYTE == vectorDataType) { Optional bytesArrayOptional = getBytesFromContext(context, dimension, vectorDataType); @@ -572,16 +577,16 @@ protected void parseCreateField( return; } final byte[] array = bytesArrayOptional.get(); - spaceType.validateVector(array); + getVectorValidator().validateVector(array); context.doc().addAll(getFieldsForByteVector(array, fieldType)); } else if (VectorDataType.FLOAT == vectorDataType) { - Optional floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext); + Optional floatsArrayOptional = getFloatsFromContext(context, dimension); if (floatsArrayOptional.isEmpty()) { return; } final float[] array = floatsArrayOptional.get(); - spaceType.validateVector(array); + getVectorValidator().validateVector(array); 
context.doc().addAll(getFieldsForFloatVector(array, fieldType)); } else { throw new IllegalArgumentException( @@ -592,80 +597,28 @@ protected void parseCreateField( context.path().remove(); } - // Verify mapping and return true if it is a "faiss" Index using "sq" encoder of type "fp16" - protected boolean isFaissSQfp16(MethodComponentContext methodComponentContext) { - if (Objects.isNull(methodComponentContext)) { - return false; - } - - if (methodComponentContext.getParameters().size() == 0) { - return false; - } - - Map methodComponentParams = methodComponentContext.getParameters(); - - // The method component parameters should have an encoder - if (!methodComponentParams.containsKey(METHOD_ENCODER_PARAMETER)) { - return false; - } - - // Validate if the object is of type MethodComponentContext before casting it later - if (!(methodComponentParams.get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext)) { - return false; - } - - MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) methodComponentParams.get(METHOD_ENCODER_PARAMETER); - - // returns true if encoder name is "sq" and type is "fp16" - return ENCODER_SQ.equals(encoderMethodComponentContext.getName()) - && FAISS_SQ_ENCODER_FP16.equals( - encoderMethodComponentContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) - ); - - } - - // Verify mapping and return the value of "clip" parameter(default false) for a "faiss" Index - // using "sq" encoder of type "fp16". 
- protected boolean isFaissSQClipToFP16RangeEnabled(MethodComponentContext methodComponentContext) { - if (Objects.nonNull(methodComponentContext)) { - return (boolean) methodComponentContext.getParameters().getOrDefault(FAISS_SQ_CLIP, false); - } - return false; - } - - void validateIfCircuitBreakerIsNotTriggered() { - if (KNNSettings.isCircuitBreakerTriggered()) { - throw new KnnCircuitBreakerException( - "Parsing the created knn vector fields prior to indexing has failed as the circuit breaker triggered. This indicates that the cluster is low on memory resources and cannot index more documents at the moment. Check _plugins/_knn/stats for the circuit breaker status." - ); - } - } - - void validateIfKNNPluginEnabled() { - if (!KNNSettings.isKNNPluginEnabled()) { - throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled setting to true"); - } - } - // Returns an optional array of byte values where each value in the vector is parsed as a float and validated // if it is a finite number without any decimals and within the byte range of [-128 to 127]. 
Optional getBytesFromContext(ParseContext context, int dimension, VectorDataType dataType) throws IOException { context.path().add(simpleName()); + PerDimensionValidator perDimensionValidator = getPerDimensionValidator(); + PerDimensionProcessor perDimensionProcessor = getPerDimensionProcessor(); + ArrayList vector = new ArrayList<>(); XContentParser.Token token = context.parser().currentToken(); if (token == XContentParser.Token.START_ARRAY) { token = context.parser().nextToken(); while (token != XContentParser.Token.END_ARRAY) { - float value = context.parser().floatValue(); - validateByteVectorValue(value, dataType); + float value = perDimensionProcessor.processByte(context.parser().floatValue()); + perDimensionValidator.validateByte(value); vector.add((byte) value); token = context.parser().nextToken(); } } else if (token == XContentParser.Token.VALUE_NUMBER) { - float value = context.parser().floatValue(); - validateByteVectorValue(value, dataType); + float value = perDimensionProcessor.processByte(context.parser().floatValue()); + perDimensionValidator.validateByte(value); vector.add((byte) value); context.parser().nextToken(); } else if (token == XContentParser.Token.VALUE_NULL) { @@ -681,21 +634,11 @@ Optional getBytesFromContext(ParseContext context, int dimension, Vector return Optional.of(array); } - Optional getFloatsFromContext(ParseContext context, int dimension, MethodComponentContext methodComponentContext) - throws IOException { + Optional getFloatsFromContext(ParseContext context, int dimension) throws IOException { context.path().add(simpleName()); - // Returns an optional array of float values where each value in the vector is parsed as a float and validated - // if it is a finite number and within the fp16 range of [-65504 to 65504] by default if Faiss encoder is SQ and type is 'fp16'. - // If the encoder parameter, "clip" is set to True, if the vector value is outside the FP16 range then it will be - // clipped to FP16 range. 
- boolean isFaissSQfp16Flag = isFaissSQfp16(methodComponentContext); - boolean clipVectorValueToFP16RangeFlag = false; - if (isFaissSQfp16Flag) { - clipVectorValueToFP16RangeFlag = isFaissSQClipToFP16RangeEnabled( - (MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER) - ); - } + PerDimensionValidator perDimensionValidator = getPerDimensionValidator(); + PerDimensionProcessor perDimensionProcessor = getPerDimensionProcessor(); ArrayList vector = new ArrayList<>(); XContentParser.Token token = context.parser().currentToken(); @@ -703,31 +646,14 @@ Optional getFloatsFromContext(ParseContext context, int dimension, Meth if (token == XContentParser.Token.START_ARRAY) { token = context.parser().nextToken(); while (token != XContentParser.Token.END_ARRAY) { - value = context.parser().floatValue(); - if (isFaissSQfp16Flag) { - if (clipVectorValueToFP16RangeFlag) { - value = clipVectorValueToFP16Range(value); - } else { - validateFP16VectorValue(value); - } - } else { - validateFloatVectorValue(value); - } - + value = perDimensionProcessor.process(context.parser().floatValue()); + perDimensionValidator.validate(value); vector.add(value); token = context.parser().nextToken(); } } else if (token == XContentParser.Token.VALUE_NUMBER) { - value = context.parser().floatValue(); - if (isFaissSQfp16Flag) { - if (clipVectorValueToFP16RangeFlag) { - value = clipVectorValueToFP16Range(value); - } else { - validateFP16VectorValue(value); - } - } else { - validateFloatVectorValue(value); - } + value = perDimensionProcessor.process(context.parser().floatValue()); + perDimensionValidator.validate(value); vector.add(value); context.parser().nextToken(); } else if (token == XContentParser.Token.VALUE_NULL) { @@ -746,7 +672,12 @@ Optional getFloatsFromContext(ParseContext context, int dimension, Meth @Override public ParametrizedFieldMapper.Builder getMergeBuilder() { - return new KNNVectorFieldMapper.Builder(simpleName(), modelDao, 
indexCreatedVersion).init(this); + return new KNNVectorFieldMapper.Builder( + simpleName(), + modelDao, + indexCreatedVersion, + fieldType().getKnnMappingConfig().getKnnMethodContext().orElse(null) + ).init(this); } @Override diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 2adbbb6953..0caaf80ab0 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -13,30 +13,45 @@ import lombok.AccessLevel; import lombok.NoArgsConstructor; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.StoredField; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.util.BytesRef; -import org.opensearch.index.mapper.ParametrizedFieldMapper; +import org.opensearch.Version; +import org.opensearch.common.settings.Settings; +import org.opensearch.index.mapper.Mapper; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.KnnCircuitBreakerException; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.indices.ModelDao; -import org.opensearch.knn.indices.ModelMetadata; -import org.opensearch.knn.indices.ModelUtil; +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.util.IndexHyperParametersUtil; import java.util.Arrays; import java.util.Locale; +import java.util.Map; +import java.util.Objects; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; +import static 
org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE; import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE; +import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; @@ -44,19 +59,10 @@ /** * Utility class for KNNVectorFieldMapper */ +@Log4j2 @NoArgsConstructor(access = AccessLevel.PRIVATE) public class KNNVectorFieldMapperUtil { - private static ModelDao modelDao; - - /** - * Initializes static instance variables - * @param modelDao ModelDao object - */ - public static void initialize(final ModelDao modelDao) { - KNNVectorFieldMapperUtil.modelDao = modelDao; - } - /** * Validate the float vector value and throw exception if it is not a number or not in the finite range * or is not within the FP16 range of [-65504 to 65504]. 
@@ -150,35 +156,6 @@ public static void validateVectorDataType(KNNMethodContext methodContext, Vector throw new IllegalArgumentException("This line should not be reached"); } - /** - * Validates and throws exception if index.knn is set to true in the index settings - * using any VectorDataType (other than float, which is default) because we are using NMSLIB engine - * for LegacyFieldMapper, and it only supports float VectorDataType - * - * @param knnIndexSetting index.knn setting in the index settings - * @param vectorDataType VectorDataType Parameter - */ - public static void validateVectorDataTypeWithKnnIndexSetting( - boolean knnIndexSetting, - ParametrizedFieldMapper.Parameter vectorDataType - ) { - - if (VectorDataType.FLOAT == vectorDataType.getValue()) { - return; - } - if (knnIndexSetting) { - throw new IllegalArgumentException( - String.format( - Locale.ROOT, - "[%s] field with value [%s] is not supported for [%s] engine", - VECTOR_DATA_TYPE_FIELD, - vectorDataType.getValue().getValue(), - NMSLIB_NAME - ) - ); - } - } - /** * @param knnEngine KNNEngine * @return DocValues FieldType of type Binary @@ -237,37 +214,172 @@ public static Object deserializeStoredVector(BytesRef storedVector, VectorDataTy * @return expected vector length */ public static int getExpectedVectorLength(final KNNVectorFieldType knnVectorFieldType) { - int expectedDimensions = knnVectorFieldType.getDimension(); - if (isModelBasedIndex(expectedDimensions)) { - ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType); - expectedDimensions = modelMetadata.getDimension(); - } + int expectedDimensions = knnVectorFieldType.getKnnMappingConfig().getDimension(); return VectorDataType.BINARY == knnVectorFieldType.getVectorDataType() ? 
expectedDimensions / 8 : expectedDimensions; } - private static boolean isModelBasedIndex(int expectedDimensions) { - return expectedDimensions == -1; + /** + * Validate if the circuit breaker is triggered + */ + static void validateIfCircuitBreakerIsNotTriggered() { + if (KNNSettings.isCircuitBreakerTriggered()) { + throw new KnnCircuitBreakerException( + "Parsing the created knn vector fields prior to indexing has failed as the circuit breaker triggered. This indicates that the cluster is low on memory resources and cannot index more documents at the moment. Check _plugins/_knn/stats for the circuit breaker status." + ); + } } /** - * Returns the model metadata for a specified knn vector field + * Validate if plugin is enabled + */ + static void validateIfKNNPluginEnabled() { + if (!KNNSettings.isKNNPluginEnabled()) { + throw new IllegalStateException("KNN plugin is disabled. To enable update knn.plugin.enabled setting to true"); + } + } + + private static SpaceType getSpaceType(final Settings indexSettings, final VectorDataType vectorDataType) { + String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey()); + if (spaceType == null) { + spaceType = VectorDataType.BINARY == vectorDataType + ? KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY + : KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE; + log.info( + String.format( + "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. Setting the setting to the default value=%s", + METHOD_PARAMETER_SPACE_TYPE, + spaceType + ) + ); + } + return SpaceType.getSpace(spaceType); + } + + private static int getM(Settings indexSettings) { + String m = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_M_SETTING.getKey()); + if (m == null) { + log.info( + String.format( + "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. 
Setting the setting to the default value=%s", + HNSW_ALGO_M, + KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M + ) + ); + return KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M; + } + return Integer.parseInt(m); + } + + private static int getEfConstruction(Settings indexSettings, Version indexVersion) { + final String efConstruction = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING.getKey()); + if (efConstruction == null) { + final int defaultEFConstructionValue = IndexHyperParametersUtil.getHNSWEFConstructionValue(indexVersion); + log.info( + String.format( + "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. " + + "Picking up default value for the index =%s", + HNSW_ALGO_EF_CONSTRUCTION, + defaultEFConstructionValue + ) + ); + return defaultEFConstructionValue; + } + return Integer.parseInt(efConstruction); + } + + /** + * Verify mapping and return true if it is a "faiss" Index using "sq" encoder of type "fp16" * - * @param knnVectorField knn vector field - * @return the model metadata from knnVectorField + * @param methodComponentContext MethodComponentContext + * @return true if it is a "faiss" Index using "sq" encoder of type "fp16" */ - private static ModelMetadata getModelMetadataForField(final KNNVectorFieldType knnVectorField) { - String modelId = knnVectorField.getModelId(); + static boolean isFaissSQfp16(MethodComponentContext methodComponentContext) { + if (Objects.isNull(methodComponentContext)) { + return false; + } - if (modelId == null) { - throw new IllegalArgumentException( - String.format("Field '%s' does not have model.", knnVectorField.getKnnMethodContext().getMethodComponentContext().getName()) + if (methodComponentContext.getParameters().size() == 0) { + return false; + } + + Map methodComponentParams = methodComponentContext.getParameters(); + + // The method component parameters should have an encoder + if (!methodComponentParams.containsKey(METHOD_ENCODER_PARAMETER)) { + return 
false; + } + + // Validate if the object is of type MethodComponentContext before casting it later + if (!(methodComponentParams.get(METHOD_ENCODER_PARAMETER) instanceof MethodComponentContext)) { + return false; + } + + MethodComponentContext encoderMethodComponentContext = (MethodComponentContext) methodComponentParams.get(METHOD_ENCODER_PARAMETER); + + // returns true if encoder name is "sq" and type is "fp16" + return ENCODER_SQ.equals(encoderMethodComponentContext.getName()) + && FAISS_SQ_ENCODER_FP16.equals( + encoderMethodComponentContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) ); + + } + + /** + * Verify mapping and return the value of "clip" parameter(default false) for a "faiss" Index + * using "sq" encoder of type "fp16". + * + * @param methodComponentContext MethodComponentContext + * @return boolean value of "clip" parameter + */ + static boolean isFaissSQClipToFP16RangeEnabled(MethodComponentContext methodComponentContext) { + if (Objects.nonNull(methodComponentContext)) { + return (boolean) methodComponentContext.getParameters().getOrDefault(FAISS_SQ_CLIP, false); } + return false; + } - ModelMetadata modelMetadata = modelDao.getMetadata(modelId); - if (!ModelUtil.isModelCreated(modelMetadata)) { - throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId)); + /** + * Extract MethodComponentContext from KNNMethodContext + * + * @param knnMethodContext KNNMethodContext + * @return MethodComponentContext + */ + static MethodComponentContext getMethodComponentContext(KNNMethodContext knnMethodContext) { + if (Objects.isNull(knnMethodContext)) { + return null; } - return modelMetadata; + return knnMethodContext.getMethodComponentContext(); + } + + static KNNMethodContext createKNNMethodContextFromLegacy( + Mapper.BuilderContext context, + VectorDataType vectorDataType, + Version indexCreatedVersion + ) { + if (VectorDataType.FLOAT != vectorDataType) { + throw new IllegalArgumentException( + 
String.format( + Locale.ROOT, + "[%s] field with value [%s] is not supported for [%s] engine", + VECTOR_DATA_TYPE_FIELD, + vectorDataType.getValue(), + NMSLIB_NAME + ) + ); + } + + return new KNNMethodContext( + KNNEngine.NMSLIB, + KNNVectorFieldMapperUtil.getSpaceType(context.indexSettings(), vectorDataType), + new MethodComponentContext( + METHOD_HNSW, + Map.of( + METHOD_PARAMETER_M, + KNNVectorFieldMapperUtil.getM(context.indexSettings()), + METHOD_PARAMETER_EF_CONSTRUCTION, + KNNVectorFieldMapperUtil.getEfConstruction(context.indexSettings(), indexCreatedVersion) + ) + ) + ); } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java index 8c3815c5f9..0fbc569f77 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java @@ -9,7 +9,6 @@ import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.util.BytesRef; -import org.opensearch.common.Nullable; import org.opensearch.index.fielddata.IndexFieldData; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.TextSearchInfo; @@ -17,9 +16,7 @@ import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryShardException; import org.opensearch.knn.index.KNNVectorIndexFieldData; -import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.search.aggregations.support.CoreValuesSourceType; import org.opensearch.search.lookup.SearchLookup; @@ -27,7 +24,6 @@ import java.util.Map; import java.util.function.Supplier; -import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.deserializeStoredVector; 
/** @@ -35,49 +31,21 @@ */ @Getter public class KNNVectorFieldType extends MappedFieldType { - int dimension; - String modelId; - KNNMethodContext knnMethodContext; + KNNMappingConfig knnMappingConfig; VectorDataType vectorDataType; - SpaceType spaceType; - public KNNVectorFieldType(String name, Map meta, int dimension, VectorDataType vectorDataType, SpaceType spaceType) { - this(name, meta, dimension, null, null, vectorDataType, spaceType); - } - - public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext) { - this(name, meta, dimension, knnMethodContext, null, DEFAULT_VECTOR_DATA_TYPE_FIELD, knnMethodContext.getSpaceType()); - } - - public KNNVectorFieldType(String name, Map meta, int dimension, KNNMethodContext knnMethodContext, String modelId) { - this(name, meta, dimension, knnMethodContext, modelId, DEFAULT_VECTOR_DATA_TYPE_FIELD, null); - } - - public KNNVectorFieldType( - String name, - Map meta, - int dimension, - KNNMethodContext knnMethodContext, - VectorDataType vectorDataType - ) { - this(name, meta, dimension, knnMethodContext, null, vectorDataType, knnMethodContext.getSpaceType()); - } - - public KNNVectorFieldType( - String name, - Map meta, - int dimension, - @Nullable KNNMethodContext knnMethodContext, - @Nullable String modelId, - VectorDataType vectorDataType, - @Nullable SpaceType spaceType - ) { - super(name, false, false, true, TextSearchInfo.NONE, meta); - this.dimension = dimension; - this.modelId = modelId; - this.knnMethodContext = knnMethodContext; + /** + * Constructor for KNNVectorFieldType. 
+ * + * @param name name of the field + * @param metadata metadata of the field + * @param vectorDataType data type of the vector + * @param annConfig configuration context for the ANN index + */ + public KNNVectorFieldType(String name, Map metadata, VectorDataType vectorDataType, KNNMappingConfig annConfig) { + super(name, false, false, true, TextSearchInfo.NONE, metadata); this.vectorDataType = vectorDataType; - this.spaceType = spaceType; + this.knnMappingConfig = annConfig; } @Override diff --git a/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java deleted file mode 100644 index cf5ec933a1..0000000000 --- a/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.mapper; - -import lombok.extern.log4j.Log4j2; -import org.apache.lucene.document.FieldType; -import org.opensearch.Version; -import org.opensearch.common.Explicit; -import org.opensearch.common.settings.Settings; -import org.opensearch.index.mapper.ParametrizedFieldMapper; -import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.util.IndexHyperParametersUtil; -import org.opensearch.knn.index.engine.KNNEngine; - -import static org.opensearch.knn.common.KNNConstants.DIMENSION; -import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION; -import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M; -import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; -import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; - -/** - * Field mapper for original implementation. It defaults to using nmslib as the engine and retrieves parameters from index settings. 
- * - * Example of this mapper output: - * - * { - * "type": "knn_vector", - * "dimension": 128 - * } - */ -@Log4j2 -public class LegacyFieldMapper extends KNNVectorFieldMapper { - - protected String spaceType; - protected String m; - protected String efConstruction; - - LegacyFieldMapper( - String simpleName, - KNNVectorFieldType mappedFieldType, - MultiFields multiFields, - CopyTo copyTo, - Explicit ignoreMalformed, - boolean stored, - boolean hasDocValues, - String spaceType, - String m, - String efConstruction, - Version indexCreatedVersion - ) { - super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion); - - this.spaceType = spaceType; - this.m = m; - this.efConstruction = efConstruction; - - this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); - - this.fieldType.putAttribute(DIMENSION, String.valueOf(dimension)); - this.fieldType.putAttribute(SPACE_TYPE, spaceType); - this.fieldType.putAttribute(KNN_ENGINE, KNNEngine.NMSLIB.getName()); - - // These are extra just for legacy - this.fieldType.putAttribute(HNSW_ALGO_M, m); - this.fieldType.putAttribute(HNSW_ALGO_EF_CONSTRUCTION, efConstruction); - - this.fieldType.freeze(); - } - - @Override - public ParametrizedFieldMapper.Builder getMergeBuilder() { - return new KNNVectorFieldMapper.Builder(simpleName(), this.spaceType, this.m, this.efConstruction, this.indexCreatedVersion).init( - this - ); - } - - static String getSpaceType(final Settings indexSettings, final VectorDataType vectorDataType) { - String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey()); - if (spaceType == null) { - spaceType = VectorDataType.BINARY == vectorDataType - ? KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE_FOR_BINARY - : KNNSettings.INDEX_KNN_DEFAULT_SPACE_TYPE; - log.info( - String.format( - "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. 
Setting the setting to the default value=%s", - METHOD_PARAMETER_SPACE_TYPE, - spaceType - ) - ); - } - return spaceType; - } - - static String getM(Settings indexSettings) { - String m = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_M_SETTING.getKey()); - if (m == null) { - log.info( - String.format( - "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. Setting the setting to the default value=%s", - HNSW_ALGO_M, - KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M - ) - ); - return String.valueOf(KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_M); - } - return m; - } - - static String getEfConstruction(Settings indexSettings, Version indexVersion) { - final String efConstruction = indexSettings.get(KNNSettings.INDEX_KNN_ALGO_PARAM_EF_CONSTRUCTION_SETTING.getKey()); - if (efConstruction == null) { - final String defaultEFConstructionValue = String.valueOf(IndexHyperParametersUtil.getHNSWEFConstructionValue(indexVersion)); - log.info( - String.format( - "[KNN] The setting \"%s\" was not set for the index. Likely caused by recent version upgrade. 
" - + "Picking up default value for the index =%s", - HNSW_ALGO_EF_CONSTRUCTION, - defaultEFConstructionValue - ) - ); - return defaultEFConstructionValue; - } - return efConstruction; - } -} diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index c82afb9e72..665c35f6e3 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -8,6 +8,9 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.Map; +import java.util.Optional; + import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NonNull; @@ -16,11 +19,12 @@ import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.index.VectorSimilarityFunction; +import org.opensearch.Version; import org.opensearch.common.Explicit; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForByteVector; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.createStoredFieldForFloatVector; @@ -35,44 +39,76 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper { private final FieldType vectorFieldType; private final VectorDataType vectorDataType; - LuceneFieldMapper(final CreateLuceneFieldMapperInput input) { + private PerDimensionProcessor perDimensionProcessor; + private PerDimensionValidator perDimensionValidator; + private VectorValidator vectorValidator; + + static LuceneFieldMapper createFieldMapper( + String fullname, + Map metaValue, + VectorDataType vectorDataType, + Integer dimension, + 
KNNMethodContext knnMethodContext, + CreateLuceneFieldMapperInput createLuceneFieldMapperInput + ) { + final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(fullname, metaValue, vectorDataType, new KNNMappingConfig() { + @Override + public Optional getKnnMethodContext() { + return Optional.of(knnMethodContext); + } + + @Override + public int getDimension() { + return dimension; + } + }); + + return new LuceneFieldMapper(mappedFieldType, createLuceneFieldMapperInput); + } + + private LuceneFieldMapper(final KNNVectorFieldType mappedFieldType, final CreateLuceneFieldMapperInput input) { super( input.getName(), - input.getMappedFieldType(), + mappedFieldType, input.getMultiFields(), input.getCopyTo(), input.getIgnoreMalformed(), input.isStored(), input.isHasDocValues(), - input.getKnnMethodContext().getMethodComponentContext().getIndexVersion() + input.getIndexVersion(), + mappedFieldType.knnMappingConfig.getKnnMethodContext().orElse(null) ); - + KNNMappingConfig knnMappingConfig = mappedFieldType.getKnnMappingConfig(); + KNNMethodContext knnMethodContext = knnMappingConfig.getKnnMethodContext() + .orElseThrow(() -> new IllegalArgumentException("KNN method context is missing")); vectorDataType = input.getVectorDataType(); - this.knnMethod = input.getKnnMethodContext(); - final VectorSimilarityFunction vectorSimilarityFunction = this.knnMethod.getSpaceType() + + final VectorSimilarityFunction vectorSimilarityFunction = knnMethodContext.getSpaceType() .getKnnVectorSimilarityFunction() .getVectorSimilarityFunction(); - final int dimension = input.getMappedFieldType().getDimension(); - if (dimension > KNNEngine.getMaxDimensionByEngine(KNNEngine.LUCENE)) { + if (knnMappingConfig.getDimension() > KNNEngine.getMaxDimensionByEngine(KNNEngine.LUCENE)) { throw new IllegalArgumentException( String.format( Locale.ROOT, "Dimension value cannot be greater than [%s] but got [%s] for vector [%s]", KNNEngine.getMaxDimensionByEngine(KNNEngine.LUCENE), - dimension, + 
knnMappingConfig.getDimension(), input.getName() ) ); } - this.fieldType = vectorDataType.createKnnVectorFieldType(dimension, vectorSimilarityFunction); + this.fieldType = vectorDataType.createKnnVectorFieldType(knnMappingConfig.getDimension(), vectorSimilarityFunction); if (this.hasDocValues) { - this.vectorFieldType = buildDocValuesFieldType(this.knnMethod.getKnnEngine()); + this.vectorFieldType = buildDocValuesFieldType(knnMethodContext.getKnnEngine()); } else { this.vectorFieldType = null; } + + initValidatorsAndProcessors(knnMethodContext); + knnMethodContext.getSpaceType().validateVectorDataType(vectorDataType); } @Override @@ -105,6 +141,36 @@ protected List getFieldsForByteVector(final byte[] array, final FieldType return fieldsToBeAdded; } + private void initValidatorsAndProcessors(KNNMethodContext knnMethodContext) { + this.vectorValidator = new SpaceVectorValidator(knnMethodContext.getSpaceType()); + this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + if (VectorDataType.BINARY == vectorDataType) { + this.perDimensionValidator = PerDimensionValidator.DEFAULT_BIT_VALIDATOR; + return; + } + + if (VectorDataType.BYTE == vectorDataType) { + this.perDimensionValidator = PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; + return; + } + this.perDimensionValidator = PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; + } + + @Override + protected VectorValidator getVectorValidator() { + return vectorValidator; + } + + @Override + protected PerDimensionValidator getPerDimensionValidator() { + return perDimensionValidator; + } + + @Override + protected PerDimensionProcessor getPerDimensionProcessor() { + return perDimensionProcessor; + } + @Override void updateEngineStats() { KNNEngine.LUCENE.setInitialized(true); @@ -117,8 +183,6 @@ static class CreateLuceneFieldMapperInput { @NonNull String name; @NonNull - KNNVectorFieldType mappedFieldType; - @NonNull MultiFields multiFields; @NonNull CopyTo copyTo; @@ -127,7 +191,7 @@ static class 
CreateLuceneFieldMapperInput { boolean stored; boolean hasDocValues; VectorDataType vectorDataType; - @NonNull - KNNMethodContext knnMethodContext; + Version indexVersion; + KNNMethodContext originalKnnMethodContext; } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index b15ab14894..7a69c941b8 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -6,37 +6,64 @@ package org.opensearch.knn.index.mapper; import org.apache.lucene.document.FieldType; +import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponentContext; import java.io.IOException; import java.util.Map; +import java.util.Optional; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.getMethodComponentContext; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.isFaissSQClipToFP16RangeEnabled; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.isFaissSQfp16; /** * Field mapper for method definition in mapping */ public class MethodFieldMapper extends KNNVectorFieldMapper { - MethodFieldMapper( + 
private PerDimensionProcessor perDimensionProcessor; + private PerDimensionValidator perDimensionValidator; + private VectorValidator vectorValidator; + + public static MethodFieldMapper createFieldMapper( + String fullname, String simpleName, - KNNVectorFieldType mappedFieldType, + Map metaValue, + VectorDataType vectorDataType, + Integer dimension, + KNNMethodContext knnMethodContext, + KNNMethodContext originalKNNMethodContext, MultiFields multiFields, CopyTo copyTo, Explicit ignoreMalformed, boolean stored, boolean hasDocValues, - KNNMethodContext knnMethodContext + Version indexCreatedVersion ) { + final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(fullname, metaValue, vectorDataType, new KNNMappingConfig() { + @Override + public Optional getKnnMethodContext() { + return Optional.of(knnMethodContext); + } - super( + @Override + public int getDimension() { + return dimension; + } + }); + return new MethodFieldMapper( simpleName, mappedFieldType, multiFields, @@ -44,14 +71,40 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { ignoreMalformed, stored, hasDocValues, - knnMethodContext.getMethodComponentContext().getIndexVersion() + indexCreatedVersion, + originalKNNMethodContext ); + } - this.knnMethod = knnMethodContext; + private MethodFieldMapper( + String simpleName, + KNNVectorFieldType mappedFieldType, + MultiFields multiFields, + CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues, + Version indexVerision, + KNNMethodContext originalKNNMethodContext + ) { + super( + simpleName, + mappedFieldType, + multiFields, + copyTo, + ignoreMalformed, + stored, + hasDocValues, + indexVerision, + originalKNNMethodContext + ); + KNNMappingConfig annConfig = mappedFieldType.getKnnMappingConfig(); + KNNMethodContext knnMethodContext = annConfig.getKnnMethodContext() + .orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); this.fieldType = new 
FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); - this.fieldType.putAttribute(DIMENSION, String.valueOf(dimension)); + this.fieldType.putAttribute(DIMENSION, String.valueOf(annConfig.getDimension())); this.fieldType.putAttribute(SPACE_TYPE, knnMethodContext.getSpaceType().getValue()); this.fieldType.putAttribute(VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); @@ -66,5 +119,57 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { } this.fieldType.freeze(); + initValidatorsAndProcessors(knnMethodContext); + knnMethodContext.getSpaceType().validateVectorDataType(vectorDataType); + } + + private void initValidatorsAndProcessors(KNNMethodContext knnMethodContext) { + this.vectorValidator = new SpaceVectorValidator(knnMethodContext.getSpaceType()); + + if (VectorDataType.BINARY == vectorDataType) { + this.perDimensionValidator = PerDimensionValidator.DEFAULT_BIT_VALIDATOR; + this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + return; + } + + if (VectorDataType.BYTE == vectorDataType) { + this.perDimensionValidator = PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; + this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + return; + } + + MethodComponentContext methodComponentContext = getMethodComponentContext(knnMethodContext); + if (!isFaissSQfp16(methodComponentContext)) { + // Neither binary/byte nor Faiss SQfp16: plain float validation, no per-dimension processing + this.perDimensionValidator = PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; + this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + return; + } + + this.perDimensionValidator = PerDimensionValidator.DEFAULT_FP16_VALIDATOR; + + if (!isFaissSQClipToFP16RangeEnabled( + (MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER) + )) { + this.perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + return; + } + + this.perDimensionProcessor = PerDimensionProcessor.CLIP_TO_FP16_PROCESSOR; + } + + @Override + protected VectorValidator getVectorValidator() { + return
vectorValidator; + } + + @Override + protected PerDimensionValidator getPerDimensionValidator() { + return perDimensionValidator; + } + + @Override + protected PerDimensionProcessor getPerDimensionProcessor() { + return perDimensionProcessor; } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index adaaef28e6..a21a01a5dc 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -9,65 +9,198 @@ import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.index.mapper.ParseContext; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelUtil; import java.io.IOException; +import java.util.Map; +import java.util.Optional; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.isFaissSQClipToFP16RangeEnabled; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.isFaissSQfp16; /** * Field mapper for model in mapping */ public class ModelFieldMapper extends KNNVectorFieldMapper { - ModelFieldMapper( + // If the dimension has not yet been set because we do not have access to model metadata, it will be -1 + public static final int UNSET_MODEL_DIMENSION_IDENTIFIER = -1; + + private PerDimensionProcessor perDimensionProcessor; + private PerDimensionValidator perDimensionValidator; + private VectorValidator vectorValidator; + + private final String modelId; + + public static ModelFieldMapper createFieldMapper( + String fullname, String simpleName, - KNNVectorFieldType mappedFieldType, 
+ Map metaValue, + VectorDataType vectorDataType, + String modelId, MultiFields multiFields, CopyTo copyTo, Explicit ignoreMalformed, boolean stored, boolean hasDocValues, ModelDao modelDao, - String modelId, Version indexCreatedVersion ) { - super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion); - this.modelId = modelId; + final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(fullname, metaValue, vectorDataType, new KNNMappingConfig() { + @Override + public Optional getModelId() { + return Optional.of(modelId); + } + + @Override + public int getDimension() { + return getModelMetadata(modelDao, modelId).getDimension(); + } + }); + return new ModelFieldMapper( + simpleName, + mappedFieldType, + multiFields, + copyTo, + ignoreMalformed, + stored, + hasDocValues, + modelDao, + indexCreatedVersion + ); + } + + private ModelFieldMapper( + String simpleName, + KNNVectorFieldType mappedFieldType, + MultiFields multiFields, + CopyTo copyTo, + Explicit ignoreMalformed, + boolean stored, + boolean hasDocValues, + ModelDao modelDao, + Version indexCreatedVersion + ) { + super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion, null); + KNNMappingConfig annConfig = mappedFieldType.getKnnMappingConfig(); + modelId = annConfig.getModelId().orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); this.modelDao = modelDao; + // For the model field mapper, we cannot validate the model during index creation due to + // an issue with reading cluster state during mapper creation. So, we need to validate the + // model when ingestion starts. 
We do this as lazily as we can + this.perDimensionProcessor = null; + this.perDimensionValidator = null; + this.vectorValidator = null; + this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); this.fieldType.putAttribute(MODEL_ID, modelId); this.fieldType.freeze(); } + @Override + protected VectorValidator getVectorValidator() { + initVectorValidator(); + return vectorValidator; + } + + @Override + protected PerDimensionValidator getPerDimensionValidator() { + initPerDimensionValidator(); + return perDimensionValidator; + } + + @Override + protected PerDimensionProcessor getPerDimensionProcessor() { + initPerDimensionProcessor(); + return perDimensionProcessor; + } + + private void initVectorValidator() { + if (vectorValidator != null) { + return; + } + ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId); + vectorValidator = new SpaceVectorValidator(modelMetadata.getSpaceType()); + } + + private void initPerDimensionValidator() { + if (perDimensionValidator != null) { + return; + } + ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId); + MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); + VectorDataType dataType = modelMetadata.getVectorDataType(); + + if (VectorDataType.BINARY == dataType) { + perDimensionValidator = PerDimensionValidator.DEFAULT_BIT_VALIDATOR; + return; + } + + if (VectorDataType.BYTE == dataType) { + perDimensionValidator = PerDimensionValidator.DEFAULT_BYTE_VALIDATOR; + return; + } + + if (!isFaissSQfp16(methodComponentContext)) { + perDimensionValidator = PerDimensionValidator.DEFAULT_FLOAT_VALIDATOR; + return; + } + + perDimensionValidator = PerDimensionValidator.DEFAULT_FP16_VALIDATOR; + } + + private void initPerDimensionProcessor() { + if (perDimensionProcessor != null) { + return; + } + ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId); + MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); + 
VectorDataType dataType = modelMetadata.getVectorDataType(); + + if (VectorDataType.BINARY == dataType) { + perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + return; + } + + if (VectorDataType.BYTE == dataType) { + perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + return; + } + + if (!isFaissSQfp16(methodComponentContext)) { + perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + return; + } + + if (!isFaissSQClipToFP16RangeEnabled( + (MethodComponentContext) methodComponentContext.getParameters().get(METHOD_ENCODER_PARAMETER) + )) { + perDimensionProcessor = PerDimensionProcessor.NOOP_PROCESSOR; + return; + } + perDimensionProcessor = PerDimensionProcessor.CLIP_TO_FP16_PROCESSOR; + } + @Override protected void parseCreateField(ParseContext context) throws IOException { - // For the model field mapper, we cannot validate the model during index creation due to - // an issue with reading cluster state during mapper creation. So, we need to validate the - // model when ingestion starts. - ModelMetadata modelMetadata = this.modelDao.getMetadata(modelId); + validatePreparse(); + ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId); + parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getVectorDataType()); + } + private static ModelMetadata getModelMetadata(ModelDao modelDao, String modelId) { + ModelMetadata modelMetadata = modelDao.getMetadata(modelId); if (!ModelUtil.isModelCreated(modelMetadata)) { - throw new IllegalStateException( - String.format( - "Model \"%s\" from %s's mapping is not created. 
Because the \"%s\" parameter is not updatable, this index will need to be recreated with a valid model.", - modelId, - context.mapperService().index().getName(), - MODEL_ID - ) - ); + throw new IllegalStateException(String.format("Model ID '%s' is not created.", modelId)); } - - parseCreateField( - context, - modelMetadata.getDimension(), - modelMetadata.getSpaceType(), - modelMetadata.getMethodComponentContext(), - modelMetadata.getVectorDataType() - ); + return modelMetadata; } } diff --git a/src/main/java/org/opensearch/knn/index/mapper/PerDimensionProcessor.java b/src/main/java/org/opensearch/knn/index/mapper/PerDimensionProcessor.java new file mode 100644 index 0000000000..21139f2ad4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/PerDimensionProcessor.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range; + +/** + * Process values per dimension. Good to have if we want to do some kind of cleanup on data as it is coming in. + */ +public interface PerDimensionProcessor { + + /** + * Process float value per dimension. + * + * @param value value to process + * @return processed value + */ + default float process(float value) { + return value; + } + + /** + * Process byte as float value per dimension. + * + * @param value value to process + * @return processed value + */ + default float processByte(float value) { + return value; + } + + PerDimensionProcessor NOOP_PROCESSOR = new PerDimensionProcessor() { + }; + + // When the encoder parameter "clip" is set to true, any vector value outside the FP16 range is + // clipped to the FP16 range.
+ PerDimensionProcessor CLIP_TO_FP16_PROCESSOR = new PerDimensionProcessor() { + @Override + public float process(float value) { + return clipVectorValueToFP16Range(value); + } + + @Override + public float processByte(float value) { + throw new IllegalStateException("CLIP_TO_FP16_PROCESSOR should not be called with byte type"); + } + }; +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/PerDimensionValidator.java b/src/main/java/org/opensearch/knn/index/mapper/PerDimensionValidator.java new file mode 100644 index 0000000000..2ca0761c02 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/PerDimensionValidator.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import org.opensearch.knn.index.VectorDataType; + +import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; +import static org.opensearch.knn.common.KNNValidationUtil.validateFloatVectorValue; +import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue; + +/** + * Validates per dimension fields + */ +public interface PerDimensionValidator { + /** + * Validates the given float is valid for the configuration + * + * @param value to validate + */ + default void validate(float value) {} + + /** + * Validates the given float as a byte is valid for the configuration. + * + * @param value to validate + */ + default void validateByte(float value) {} + + PerDimensionValidator DEFAULT_FLOAT_VALIDATOR = new PerDimensionValidator() { + @Override + public void validate(float value) { + validateFloatVectorValue(value); + } + + @Override + public void validateByte(float value) { + throw new IllegalStateException("DEFAULT_FLOAT_VALIDATOR should only be used for float vectors"); + } + }; + + // Validates if it is a finite number and within the fp16 range of [-65504 to 65504]. 
+ PerDimensionValidator DEFAULT_FP16_VALIDATOR = new PerDimensionValidator() { + @Override + public void validate(float value) { + validateFP16VectorValue(value); + } + + @Override + public void validateByte(float value) { + throw new IllegalStateException("DEFAULT_FP16_VALIDATOR should only be used for float vectors"); + } + }; + + PerDimensionValidator DEFAULT_BYTE_VALIDATOR = new PerDimensionValidator() { + @Override + public void validate(float value) { + throw new IllegalStateException("DEFAULT_BYTE_VALIDATOR should only be used for byte values"); + } + + @Override + public void validateByte(float value) { + validateByteVectorValue(value, VectorDataType.BYTE); + } + }; + + PerDimensionValidator DEFAULT_BIT_VALIDATOR = new PerDimensionValidator() { + @Override + public void validate(float value) { + throw new IllegalStateException("DEFAULT_BIT_VALIDATOR should only be used for byte values"); + } + + @Override + public void validateByte(float value) { + validateByteVectorValue(value, VectorDataType.BINARY); + } + }; +} diff --git a/src/main/java/org/opensearch/knn/index/mapper/SpaceVectorValidator.java b/src/main/java/org/opensearch/knn/index/mapper/SpaceVectorValidator.java new file mode 100644 index 0000000000..6ff088604c --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/SpaceVectorValidator.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +import lombok.AllArgsConstructor; +import org.opensearch.knn.index.SpaceType; + +/** + * Confirms that a given vector is valid for the provided space type + */ +@AllArgsConstructor +public class SpaceVectorValidator implements VectorValidator { + + private final SpaceType spaceType; + + @Override + public void validateVector(byte[] vector) { + spaceType.validateVector(vector); + } + + @Override + public void validateVector(float[] vector) { + spaceType.validateVector(vector); + } +} diff --git 
a/src/main/java/org/opensearch/knn/index/mapper/VectorValidator.java b/src/main/java/org/opensearch/knn/index/mapper/VectorValidator.java new file mode 100644 index 0000000000..f4253ae373 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/mapper/VectorValidator.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.mapper; + +/** + * Class validates vector after it has been parsed + */ +public interface VectorValidator { + /** + * Validate if the given byte vector is supported + * + * @param vector the given vector + */ + default void validateVector(byte[] vector) {} + + /** + * Validate if the given float vector is supported + * + * @param vector the given vector + */ + default void validateVector(float[] vector) {} + + VectorValidator NOOP_VECTOR_VALIDATOR = new VectorValidator() { + }; +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 6d57cb2dd5..80833751ec 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -24,9 +24,9 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.engine.model.QueryContext; +import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.mapper.KNNVectorFieldType; import org.opensearch.knn.index.util.IndexUtil; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; @@ -43,6 +43,7 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; import static 
org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; @@ -341,36 +342,47 @@ protected Query doToQuery(QueryShardContext context) { if (!(mappedFieldType instanceof KNNVectorFieldType)) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' is not knn_vector type.", this.fieldName)); } - KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldType) mappedFieldType; - int fieldDimension = knnVectorFieldType.getDimension(); - KNNMethodContext knnMethodContext = knnVectorFieldType.getKnnMethodContext(); - MethodComponentContext methodComponentContext = null; - KNNEngine knnEngine = KNNEngine.DEFAULT; - VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType(); - SpaceType spaceType = knnVectorFieldType.getSpaceType(); + KNNMappingConfig knnMappingConfig = knnVectorFieldType.getKnnMappingConfig(); + final AtomicReference queryConfigFromMapping = new AtomicReference<>(); + int fieldDimension = knnMappingConfig.getDimension(); + knnMappingConfig.getKnnMethodContext() + .ifPresentOrElse( + knnMethodContext -> queryConfigFromMapping.set( + new QueryConfigFromMapping( + knnMethodContext.getKnnEngine(), + knnMethodContext.getMethodComponentContext(), + knnMethodContext.getSpaceType(), + knnVectorFieldType.getVectorDataType() + ) + ), + () -> knnMappingConfig.getModelId().ifPresentOrElse(modelId -> { + ModelMetadata modelMetadata = getModelMetadataForField(modelId); + queryConfigFromMapping.set( + new QueryConfigFromMapping( + modelMetadata.getKnnEngine(), + modelMetadata.getMethodComponentContext(), + modelMetadata.getSpaceType(), + modelMetadata.getVectorDataType() + ) + ); + }, + () -> { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Field '%s' is not built for ANN search.", this.fieldName) + ); + } + ) + ); + KNNEngine knnEngine = queryConfigFromMapping.get().getKnnEngine(); + MethodComponentContext methodComponentContext = 
queryConfigFromMapping.get().getMethodComponentContext(); + SpaceType spaceType = queryConfigFromMapping.get().getSpaceType(); + VectorDataType vectorDataType = queryConfigFromMapping.get().getVectorDataType(); + VectorQueryType vectorQueryType = getVectorQueryType(k, maxDistance, minScore); updateQueryStats(vectorQueryType); - if (fieldDimension == -1) { - if (spaceType != null) { - throw new IllegalStateException("Space type should be null when the field uses a model"); - } - // If dimension is not set, the field uses a model and the information needs to be retrieved from there - ModelMetadata modelMetadata = getModelMetadataForField(knnVectorFieldType); - fieldDimension = modelMetadata.getDimension(); - knnEngine = modelMetadata.getKnnEngine(); - spaceType = modelMetadata.getSpaceType(); - methodComponentContext = modelMetadata.getMethodComponentContext(); - vectorDataType = modelMetadata.getVectorDataType(); - - } else if (knnMethodContext != null) { - // If the dimension is set but the knnMethodContext is not then the field is using the legacy mapping - knnEngine = knnMethodContext.getKnnEngine(); - spaceType = knnMethodContext.getSpaceType(); - methodComponentContext = knnMethodContext.getMethodComponentContext(); - } - + // This could be null in the case of when a model did not have serialized methodComponent information final String method = methodComponentContext != null ? 
methodComponentContext.getName() : null; if (StringUtils.isNotBlank(method)) { final KNNLibrarySearchContext engineSpecificMethodContext = knnEngine.getKNNLibrarySearchContext(method); @@ -492,13 +504,7 @@ protected Query doToQuery(QueryShardContext context) { throw new IllegalArgumentException(String.format(Locale.ROOT, "[%s] requires k or distance or score to be set", NAME)); } - private ModelMetadata getModelMetadataForField(KNNVectorFieldType knnVectorField) { - String modelId = knnVectorField.getModelId(); - - if (modelId == null) { - throw new IllegalArgumentException(String.format(Locale.ROOT, "Field '%s' does not have model.", this.fieldName)); - } - + private ModelMetadata getModelMetadataForField(String modelId) { ModelMetadata modelMetadata = modelDao.getMetadata(modelId); if (!ModelUtil.isModelCreated(modelMetadata)) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId)); @@ -568,4 +574,13 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryShardContext) throws I } return super.doRewrite(queryShardContext); } + + @Getter + @AllArgsConstructor + private static class QueryConfigFromMapping { + private final KNNEngine knnEngine; + private final MethodComponentContext methodComponentContext; + private final SpaceType spaceType; + private final VectorDataType vectorDataType; + } } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index e0fd250605..efb4bdf932 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -14,7 +14,6 @@ import org.opensearch.indices.SystemIndexDescriptor; import org.opensearch.knn.index.KNNCircuitBreaker; import org.opensearch.knn.index.util.KNNClusterUtil; -import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.KNNSettings; 
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; @@ -203,7 +202,6 @@ public Collection createComponents( TrainingJobClusterStateListener.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); KNNCircuitBreaker.getInstance().initialize(threadPool, clusterService, client); KNNQueryBuilder.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); - KNNVectorFieldMapperUtil.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); KNNWeight.initialize(ModelDao.OpenSearchKNNModelDao.getInstance()); TrainingModelRequest.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService); diff --git a/src/test/java/org/opensearch/knn/KNNTestCase.java b/src/test/java/org/opensearch/knn/KNNTestCase.java index 56c129546f..fb09fb30b0 100644 --- a/src/test/java/org/opensearch/knn/KNNTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNTestCase.java @@ -7,12 +7,18 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.Version; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.engine.KNNLibrarySearchContext; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponentContext; +import org.opensearch.knn.index.mapper.KNNMappingConfig; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.core.common.bytes.BytesReference; @@ -20,12 +26,15 @@ import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.test.OpenSearchTestCase; +import java.util.Collections; import java.util.HashSet; import java.util.Map; +import 
java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; /** * Base class for integration tests for KNN plugin. Contains several methods for testing KNN ES functionality. @@ -91,4 +100,50 @@ private void initKNNSettings() { public Map xContentBuilderToMap(XContentBuilder xContentBuilder) { return XContentHelper.convertToMap(BytesReference.bytes(xContentBuilder), true, xContentBuilder.contentType()).v2(); } + + public static KNNMethodContext getDefaultKNNMethodContext() { + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + KNNMethodContext defaultInstance = new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, methodComponentContext); + methodComponentContext.setIndexVersion(Version.CURRENT); + return defaultInstance; + } + + public static KNNMethodContext getDefaultBinaryKNNMethodContext() { + MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + KNNMethodContext defaultInstance = new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT_BINARY, methodComponentContext); + methodComponentContext.setIndexVersion(Version.CURRENT); + return defaultInstance; + } + + public static KNNMappingConfig getMappingConfigForMethodMapping(KNNMethodContext knnMethodContext, int dimension) { + return new KNNMappingConfig() { + @Override + public Optional getKnnMethodContext() { + return Optional.of(knnMethodContext); + } + + @Override + public int getDimension() { + return dimension; + } + }; + } + + public static KNNMappingConfig getMappingConfigForFlatMapping(int dimension) { + return () -> dimension; + } + + public static KNNMappingConfig getMappingConfigForModelMapping(String modelId, int dimension) { + return new KNNMappingConfig() { + @Override + public Optional getModelId() { + return Optional.of(modelId); + } + + 
@Override + public int getDimension() { + return dimension; + } + }; + } } diff --git a/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java b/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java index 9a867c58dd..f71fbaae0f 100644 --- a/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java @@ -94,7 +94,7 @@ public void testGetSpaceType() { */ public void testValidate() { // Check valid default - this should not throw any exception - assertNull(KNNMethodContext.getDefault().validate()); + assertNull(getDefaultKNNMethodContext().validate()); // Check a valid nmslib method MethodComponentContext hnswMethod = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index e1ebc57085..e87531561a 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -251,62 +251,6 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); } - public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException { - // Set information about the segment and the fields - String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); - int docsInSegment = 100; - String fieldName = String.format("test_field%s", randomAlphaOfLength(4)); - - KNNEngine knnEngine = KNNEngine.NMSLIB; - SpaceType spaceType = SpaceType.COSINESIMIL; - int dimension = 16; - - SegmentInfo segmentInfo = KNNCodecTestUtil.segmentInfoBuilder() - .directory(directory) - .segmentName(segmentName) - .docsInSegment(docsInSegment) - .codec(codec) - 
.build(); - - FieldInfo[] fieldInfoArray = new FieldInfo[] { - KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) - .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") - .addAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512") - .addAttribute(KNNConstants.HNSW_ALGO_M, "16") - .addAttribute(KNNConstants.SPACE_TYPE, spaceType.getValue()) - .build() }; - - FieldInfos fieldInfos = new FieldInfos(fieldInfoArray); - SegmentWriteState state = new SegmentWriteState(null, directory, segmentInfo, fieldInfos, null, IOContext.DEFAULT); - - long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); - long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); - - // Add documents to the field - KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); - TestVectorValues.RandomVectorDocValuesProducer randomVectorDocValuesProducer = new TestVectorValues.RandomVectorDocValuesProducer( - docsInSegment, - dimension - ); - knn80DocValuesConsumer.addKNNBinaryField(fieldInfoArray[0], randomVectorDocValuesProducer, true, true); - - // The document should be created in the correct location - String expectedFile = KNNCodecUtil.buildEngineFileName(segmentName, knnEngine.getVersion(), fieldName, knnEngine.getExtension()); - assertFileInCorrectLocation(state, expectedFile); - - // The footer should be valid - assertValidFooter(state.directory, expectedFile); - - // The document should be readable by nmslib - assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension); - - // The graph creation statistics should be updated - assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); - assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); - } - public 
void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException { String segmentName = String.format("test_segment%s", randomAlphaOfLength(4)); int docsInSegment = 100; diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index a0b9b32d0e..00cc2b167c 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -14,8 +14,10 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.join.BitSetProducer; +import org.opensearch.Version; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.index.mapper.MapperService; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; @@ -56,6 +58,7 @@ import java.time.ZoneOffset; import java.time.ZonedDateTime; import java.util.Arrays; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -77,6 +80,8 @@ import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; import static org.opensearch.knn.index.KNNSettings.MODEL_CACHE_SIZE_LIMIT_SETTING; @@ -86,14 +91,28 @@ public class KNNCodecTestCase extends KNNTestCase { private static final Codec ACTUAL_CODEC = KNNCodecVersion.current().getDefaultKnnCodecSupplier().get(); - private static FieldType sampleFieldType; + private static final FieldType 
sampleFieldType; static { + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + new MethodComponentContext(METHOD_HNSW, ImmutableMap.of(METHOD_PARAMETER_M, 16, METHOD_PARAMETER_EF_CONSTRUCTION, 512)) + ); + knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); + String parameterString; + try { + parameterString = XContentFactory.jsonBuilder() + .map(knnMethodContext.getKnnEngine().getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters()) + .toString(); + } catch (IOException e) { + throw new RuntimeException(e); + } + sampleFieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); - sampleFieldType.putAttribute(KNNConstants.KNN_METHOD, KNNConstants.METHOD_HNSW); - sampleFieldType.putAttribute(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()); - sampleFieldType.putAttribute(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()); - sampleFieldType.putAttribute(KNNConstants.HNSW_ALGO_M, "32"); - sampleFieldType.putAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512"); + sampleFieldType.putAttribute(KNNVectorFieldMapper.KNN_FIELD, "true"); + sampleFieldType.putAttribute(KNNConstants.KNN_ENGINE, knnMethodContext.getKnnEngine().getName()); + sampleFieldType.putAttribute(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue()); + sampleFieldType.putAttribute(KNNConstants.PARAMETERS, parameterString); sampleFieldType.freeze(); } private static final String FIELD_NAME_ONE = "test_vector_one"; @@ -309,8 +328,19 @@ public void testKnnVectorIndex( SpaceType.L2, new MethodComponentContext(METHOD_HNSW, Map.of(HNSW_ALGO_M, 16, HNSW_ALGO_EF_CONSTRUCTION, 256)) ); - final KNNVectorFieldType mappedFieldType1 = new KNNVectorFieldType(FIELD_NAME_ONE, Map.of(), 3, knnMethodContext); - final KNNVectorFieldType mappedFieldType2 = new KNNVectorFieldType(FIELD_NAME_TWO, Map.of(), 2, knnMethodContext); + + final KNNVectorFieldType mappedFieldType1 = new KNNVectorFieldType( + 
"test", + Collections.emptyMap(), + VectorDataType.FLOAT, + getMappingConfigForMethodMapping(knnMethodContext, 3) + ); + final KNNVectorFieldType mappedFieldType2 = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + getMappingConfigForMethodMapping(knnMethodContext, 2) + ); when(mapperService.fieldType(eq(FIELD_NAME_ONE))).thenReturn(mappedFieldType1); when(mapperService.fieldType(eq(FIELD_NAME_TWO))).thenReturn(mappedFieldType2); diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index c95568be22..f06ff79353 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -103,7 +103,7 @@ public class KNNVectorFieldMapperTests extends KNNTestCase { public void testBuilder_getParameters() { String fieldName = "test-field-name"; ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder(fieldName, modelDao, CURRENT, null); assertEquals(7, builder.getParameters().size()); List actualParams = builder.getParameters().stream().map(a -> a.name).collect(Collectors.toList()); @@ -114,7 +114,7 @@ public void testBuilder_getParameters() { public void testBuilder_build_fromKnnMethodContext() { // Check that knnMethodContext takes precedent over both model and legacy ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -126,6 +126,7 @@ public void 
testBuilder_build_fromKnnMethodContext() { .put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue()) .put(KNNSettings.KNN_ALGO_PARAM_M, m) .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) + .put(KNN_INDEX, true) .build(); builder.knnMethodContext.setValue( @@ -139,19 +140,17 @@ public void testBuilder_build_fromKnnMethodContext() { ) ); - builder.modelId.setValue("Random modelId"); - Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); - assertNotNull(knnVectorFieldMapper.knnMethod); - assertNull(knnVectorFieldMapper.modelId); + assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent()); + assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty()); } public void testBuilder_build_fromModel() { // Check that modelContext takes precedent over legacy ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); SpaceType spaceType = SpaceType.COSINESIMIL; int m = 17; @@ -163,6 +162,7 @@ public void testBuilder_build_fromModel() { .put(KNNSettings.KNN_SPACE_TYPE, spaceType.getValue()) .put(KNNSettings.KNN_ALGO_PARAM_M, m) .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) + .put(KNN_INDEX, true) .build(); String modelId = "Random modelId"; @@ -184,14 +184,14 @@ public void testBuilder_build_fromModel() { when(modelDao.getMetadata(modelId)).thenReturn(mockedModelMetadata); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); assertTrue(knnVectorFieldMapper instanceof ModelFieldMapper); - assertNotNull(knnVectorFieldMapper.modelId); - 
assertNull(knnVectorFieldMapper.knnMethod); + assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isPresent()); + assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isEmpty()); } public void testBuilder_build_fromLegacy() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); int m = 17; int efConstruction = 17; @@ -201,37 +201,22 @@ public void testBuilder_build_fromLegacy() { .put(settings(CURRENT).build()) .put(KNNSettings.KNN_ALGO_PARAM_M, m) .put(KNNSettings.KNN_ALGO_PARAM_EF_CONSTRUCTION, efConstruction) + .put(KNN_INDEX, true) .build(); Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof LegacyFieldMapper); - assertNull(knnVectorFieldMapper.modelId); - assertNull(knnVectorFieldMapper.knnMethod); - assertEquals(SpaceType.L2.getValue(), ((LegacyFieldMapper) knnVectorFieldMapper).spaceType); - } - - public void testBuilder_whenKnnFalseWithBinary_thenSetHammingAsDefault() { - // Check legacy is picked up if model context and method context are not set - ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); - builder.vectorDataType.setValue(VectorDataType.BINARY); - builder.dimension.setValue(8); - - // Setup settings - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); - - Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); - KNNVectorFieldMapper knnVectorFieldMapper = 
builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof LegacyFieldMapper); - assertEquals(SpaceType.HAMMING.getValue(), ((LegacyFieldMapper) knnVectorFieldMapper).spaceType); + assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); + assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().isPresent()); + assertTrue(knnVectorFieldMapper.fieldType().getKnnMappingConfig().getModelId().isEmpty()); + assertEquals(SpaceType.L2, knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType()); } public void testBuilder_parse_fromKnnMethodContext_luceneEngine() throws IOException { String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); ModelDao modelDao = mock(ModelDao.class); KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); @@ -317,7 +302,7 @@ public void testTypeParser_parse_fromKnnMethodContext_invalidDimension() throws String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).put(KNN_INDEX, true).build()).build(); ModelDao modelDao = mock(ModelDao.class); KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); @@ -616,7 +601,7 @@ public void testKNNVectorFieldMapper_merge_fromKnnMethodContext() throws IOExcep String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); ModelDao modelDao = mock(ModelDao.class); 
KNNVectorFieldMapper.TypeParser typeParser = new KNNVectorFieldMapper.TypeParser(() -> modelDao); @@ -646,12 +631,18 @@ public void testKNNVectorFieldMapper_merge_fromKnnMethodContext() throws IOExcep // merge with itself - should be successful KNNVectorFieldMapper knnVectorFieldMapperMerge1 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper1); - assertEquals(knnVectorFieldMapper1.knnMethod, knnVectorFieldMapperMerge1.knnMethod); + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getKnnMethodContext().get(), + knnVectorFieldMapperMerge1.fieldType().getKnnMappingConfig().getKnnMethodContext().get() + ); // merge with another mapper of the same field with same context KNNVectorFieldMapper knnVectorFieldMapper2 = builder.build(builderContext); KNNVectorFieldMapper knnVectorFieldMapperMerge2 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper2); - assertEquals(knnVectorFieldMapper1.knnMethod, knnVectorFieldMapperMerge2.knnMethod); + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getKnnMethodContext().get(), + knnVectorFieldMapperMerge2.fieldType().getKnnMappingConfig().getKnnMethodContext().get() + ); // merge with another mapper of the same field with different context xContentBuilder = XContentFactory.jsonBuilder() @@ -676,7 +667,7 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { String fieldName = "test-field-name"; String indexName = "test-index-name"; - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); String modelId = "test-id"; int dimension = 133; @@ -715,12 +706,18 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { // merge with itself - should be successful KNNVectorFieldMapper knnVectorFieldMapperMerge1 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper1); - 
assertEquals(knnVectorFieldMapper1.modelId, knnVectorFieldMapperMerge1.modelId); + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getModelId().get(), + knnVectorFieldMapperMerge1.fieldType().getKnnMappingConfig().getModelId().get() + ); // merge with another mapper of the same field with same context KNNVectorFieldMapper knnVectorFieldMapper2 = builder.build(builderContext); KNNVectorFieldMapper knnVectorFieldMapperMerge2 = (KNNVectorFieldMapper) knnVectorFieldMapper1.merge(knnVectorFieldMapper2); - assertEquals(knnVectorFieldMapper1.modelId, knnVectorFieldMapperMerge2.modelId); + assertEquals( + knnVectorFieldMapper1.fieldType().getKnnMappingConfig().getModelId().get(), + knnVectorFieldMapperMerge2.fieldType().getKnnMappingConfig().getModelId().get() + ); // merge with another mapper of the same field with different context xContentBuilder = XContentFactory.jsonBuilder() @@ -754,18 +751,19 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { when(parseContext.doc()).thenReturn(document); when(parseContext.path()).thenReturn(contentPath); - LuceneFieldMapper luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); - doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper) - .getFloatsFromContext(parseContext, TEST_DIMENSION, new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())); - doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); - doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - luceneFieldMapper.parseCreateField( - parseContext, - TEST_DIMENSION, - luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), - VectorDataType.FLOAT + LuceneFieldMapper luceneFieldMapper = Mockito.spy( + LuceneFieldMapper.createFieldMapper( + TEST_FIELD_NAME, + Collections.emptyMap(), + VectorDataType.FLOAT, + TEST_DIMENSION, + getDefaultKNNMethodContext(), + inputBuilder.build() + ) ); + 
doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION); + doNothing().when(luceneFieldMapper).validatePreparse(); + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnVectorField List fields = document.getFields(); @@ -798,19 +796,26 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { when(parseContext.path()).thenReturn(contentPath); inputBuilder.hasDocValues(false); - luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); - doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper) - .getFloatsFromContext(parseContext, TEST_DIMENSION, new MethodComponentContext(METHOD_HNSW, Collections.emptyMap())); - doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); - doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - - luceneFieldMapper.parseCreateField( - parseContext, - TEST_DIMENSION, - luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), - VectorDataType.FLOAT + + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.LUCENE, + SpaceType.DEFAULT, + new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) ); + luceneFieldMapper = Mockito.spy( + LuceneFieldMapper.createFieldMapper( + TEST_FIELD_NAME, + Collections.emptyMap(), + VectorDataType.FLOAT, + TEST_DIMENSION, + knnMethodContext, + inputBuilder.build() + ) + ); + doReturn(Optional.of(TEST_VECTOR)).when(luceneFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION); + doNothing().when(luceneFieldMapper).validatePreparse(); + + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); // Document should have 1 field: one for KnnVectorField fields = document.getFields(); @@ -834,19 +839,21 @@ public void 
testLuceneFieldMapper_parseCreateField_docValues_withBytes() { when(parseContext.doc()).thenReturn(document); when(parseContext.path()).thenReturn(contentPath); - LuceneFieldMapper luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); + LuceneFieldMapper luceneFieldMapper = Mockito.spy( + LuceneFieldMapper.createFieldMapper( + TEST_FIELD_NAME, + Collections.emptyMap(), + VectorDataType.BYTE, + TEST_DIMENSION, + getDefaultKNNMethodContext(), + inputBuilder.build() + ) + ); doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper) .getBytesFromContext(parseContext, TEST_DIMENSION, VectorDataType.BYTE); - doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); - doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - - luceneFieldMapper.parseCreateField( - parseContext, - TEST_DIMENSION, - luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), - VectorDataType.BYTE - ); + doNothing().when(luceneFieldMapper).validatePreparse(); + + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.BYTE); // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnByteVectorField List fields = document.getFields(); @@ -878,19 +885,21 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { when(parseContext.path()).thenReturn(contentPath); inputBuilder.hasDocValues(false); - luceneFieldMapper = Mockito.spy(new LuceneFieldMapper(inputBuilder.build())); + luceneFieldMapper = Mockito.spy( + LuceneFieldMapper.createFieldMapper( + TEST_FIELD_NAME, + Collections.emptyMap(), + VectorDataType.BYTE, + TEST_DIMENSION, + getDefaultKNNMethodContext(), + inputBuilder.build() + ) + ); doReturn(Optional.of(TEST_BYTE_VECTOR)).when(luceneFieldMapper) .getBytesFromContext(parseContext, TEST_DIMENSION, VectorDataType.BYTE); - 
doNothing().when(luceneFieldMapper).validateIfCircuitBreakerIsNotTriggered(); - doNothing().when(luceneFieldMapper).validateIfKNNPluginEnabled(); - - luceneFieldMapper.parseCreateField( - parseContext, - TEST_DIMENSION, - luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), - VectorDataType.BYTE - ); + doNothing().when(luceneFieldMapper).validatePreparse(); + + luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.BYTE); // Document should have 1 field: one for KnnByteVectorField fields = document.getFields(); @@ -970,10 +979,10 @@ private void testBuilderWithBinaryDataType( String expectedErrMsg ) { ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); // Setup settings - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); builder.knnMethodContext.setValue( new KNNMethodContext(knnEngine, spaceType, new MethodComponentContext(method, Collections.emptyMap())) @@ -986,7 +995,10 @@ private void testBuilderWithBinaryDataType( KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); assertTrue(knnVectorFieldMapper instanceof MethodFieldMapper); if (SpaceType.UNDEFINED == spaceType) { - assertEquals(SpaceType.HAMMING, knnVectorFieldMapper.fieldType().spaceType); + assertEquals( + SpaceType.HAMMING, + knnVectorFieldMapper.fieldType().getKnnMappingConfig().getKnnMethodContext().get().getSpaceType() + ); } } else { Exception ex = expectThrows(Exception.class, () -> builder.build(builderContext)); @@ -996,10 +1008,10 @@ private void testBuilderWithBinaryDataType( public void 
testBuilder_whenBinaryFaissHNSWWithSQ_thenException() { ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); // Setup settings - Settings settings = Settings.builder().put(settings(CURRENT).build()).build(); + Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); builder.knnMethodContext.setValue( new KNNMethodContext( @@ -1022,7 +1034,7 @@ public void testBuilder_whenBinaryFaissHNSWWithSQ_thenException() { public void testBuilder_whenBinaryWithLegacyKNNDisabled_thenValid() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT, null); builder.vectorDataType.setValue(VectorDataType.BINARY); builder.dimension.setValue(8); @@ -1031,13 +1043,13 @@ public void testBuilder_whenBinaryWithLegacyKNNDisabled_thenValid() { Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); KNNVectorFieldMapper knnVectorFieldMapper = builder.build(builderContext); - assertTrue(knnVectorFieldMapper instanceof LegacyFieldMapper); + assertTrue(knnVectorFieldMapper instanceof FlatVectorFieldMapper); } public void testBuilder_whenBinaryWithLegacyKNNEnabled_thenException() { // Check legacy is picked up if model context and method context are not set ModelDao modelDao = mock(ModelDao.class); - KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", modelDao, CURRENT); + KNNVectorFieldMapper.Builder builder = new KNNVectorFieldMapper.Builder("test-field-name-1", 
modelDao, CURRENT, null); builder.vectorDataType.setValue(VectorDataType.BINARY); builder.dimension.setValue(8); @@ -1052,29 +1064,14 @@ public void testBuilder_whenBinaryWithLegacyKNNEnabled_thenException() { private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder createLuceneFieldMapperInputBuilder( VectorDataType vectorDataType ) { - KNNMethodContext knnMethodContext = new KNNMethodContext( - KNNEngine.LUCENE, - SpaceType.DEFAULT, - new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()) - ); - - KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType( - TEST_FIELD_NAME, - Collections.emptyMap(), - TEST_DIMENSION, - knnMethodContext, - vectorDataType - ); - return LuceneFieldMapper.CreateLuceneFieldMapperInput.builder() .name(TEST_FIELD_NAME) - .mappedFieldType(knnVectorFieldType) .multiFields(FieldMapper.MultiFields.empty()) .copyTo(FieldMapper.CopyTo.empty()) .hasDocValues(true) .vectorDataType(vectorDataType) .ignoreMalformed(new Explicit<>(true, true)) - .knnMethodContext(knnMethodContext); + .originalKnnMethodContext(getDefaultKNNMethodContext()); } private static float[] createInitializedFloatArray(int dimension, float value) { diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java index 31da12d669..8ace5557ef 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java @@ -21,9 +21,6 @@ import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.indices.ModelDao; -import org.opensearch.knn.indices.ModelMetadata; -import org.opensearch.knn.indices.ModelState; import java.util.Arrays; import java.util.Collections; @@ -61,64 +58,24 @@ 
public void testStoredFields_whenVectorIsFloatType_thenSucceed() { public void testGetExpectedVectorLengthSuccess() { KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class); - when(knnVectorFieldType.getDimension()).thenReturn(3); - + when(knnVectorFieldType.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 3)); KNNVectorFieldType knnVectorFieldTypeBinary = mock(KNNVectorFieldType.class); - when(knnVectorFieldTypeBinary.getDimension()).thenReturn(8); + when(knnVectorFieldTypeBinary.getKnnMappingConfig()).thenReturn( + getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 8) + ); when(knnVectorFieldTypeBinary.getVectorDataType()).thenReturn(VectorDataType.BINARY); KNNVectorFieldType knnVectorFieldTypeModelBased = mock(KNNVectorFieldType.class); - when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1); + when(knnVectorFieldTypeModelBased.getKnnMappingConfig()).thenReturn( + getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 8) + ); String modelId = "test-model"; - when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId); - - ModelDao modelDao = mock(ModelDao.class); - ModelMetadata modelMetadata = mock(ModelMetadata.class); - when(modelMetadata.getState()).thenReturn(ModelState.CREATED); - when(modelMetadata.getDimension()).thenReturn(4); - when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); - - KNNVectorFieldMapperUtil.initialize(modelDao); - + when(knnVectorFieldTypeModelBased.getKnnMappingConfig()).thenReturn(getMappingConfigForModelMapping(modelId, 4)); assertEquals(3, KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldType)); assertEquals(1, KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldTypeBinary)); assertEquals(4, KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldTypeModelBased)); } - public void testGetExpectedVectorLengthFailure() { - KNNVectorFieldType knnVectorFieldTypeModelBased = 
mock(KNNVectorFieldType.class); - when(knnVectorFieldTypeModelBased.getDimension()).thenReturn(-1); - String modelId = "test-model"; - when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(modelId); - - ModelDao modelDao = mock(ModelDao.class); - ModelMetadata modelMetadata = mock(ModelMetadata.class); - when(modelMetadata.getState()).thenReturn(ModelState.TRAINING); - when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); - - KNNVectorFieldMapperUtil.initialize(modelDao); - - IllegalArgumentException e = expectThrows( - IllegalArgumentException.class, - () -> KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldTypeModelBased) - ); - assertEquals(String.format("Model ID '%s' is not created.", modelId), e.getMessage()); - - when(knnVectorFieldTypeModelBased.getModelId()).thenReturn(null); - KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); - MethodComponentContext methodComponentContext = mock(MethodComponentContext.class); - String fieldName = "test-field"; - when(methodComponentContext.getName()).thenReturn(fieldName); - when(knnMethodContext.getMethodComponentContext()).thenReturn(methodComponentContext); - when(knnVectorFieldTypeModelBased.getKnnMethodContext()).thenReturn(knnMethodContext); - - e = expectThrows( - IllegalArgumentException.class, - () -> KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldTypeModelBased) - ); - assertEquals(String.format("Field '%s' does not have model.", fieldName), e.getMessage()); - } - public void testValidateVectorDataType_whenBinaryFaissHNSW_thenValid() { validateValidateVectorDataType(KNNEngine.FAISS, KNNConstants.METHOD_HNSW, VectorDataType.BINARY, null); } diff --git a/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java index dcd2557405..faae3e35d2 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java +++ 
b/src/test/java/org/opensearch/knn/index/mapper/MethodFieldMapperTests.java @@ -5,33 +5,35 @@ package org.opensearch.knn.index.mapper; -import junit.framework.TestCase; +import org.opensearch.Version; import org.opensearch.index.mapper.FieldMapper; -import org.opensearch.knn.index.engine.KNNMethodContext; -import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.KNNMethodContext; import java.util.Collections; -public class MethodFieldMapperTests extends TestCase { - public void testMethodFieldMapper_whenVectorDataTypeIsGiven_thenSetItInFieldType() { - KNNVectorFieldType mappedFieldType = new KNNVectorFieldType( - "testField", - Collections.emptyMap(), - 1, - VectorDataType.BINARY, - SpaceType.HAMMING - ); - MethodFieldMapper mappers = new MethodFieldMapper( - "simpleName", - mappedFieldType, - null, - new FieldMapper.CopyTo.Builder().build(), - KNNVectorFieldMapper.Defaults.IGNORE_MALFORMED, - true, - true, - KNNMethodContext.getDefault() +public class MethodFieldMapperTests extends KNNTestCase { + public void testMethodFieldMapper_whenVectorDataTypeAndContextMismatch_thenThrow() { + // Expect that we cannot create the mapper with an invalid field type + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); + expectThrows( + IllegalArgumentException.class, + () -> MethodFieldMapper.createFieldMapper( + "testField", + "simpleName", + Collections.emptyMap(), + VectorDataType.BINARY, + 1, + knnMethodContext, + knnMethodContext, + null, + new FieldMapper.CopyTo.Builder().build(), + KNNVectorFieldMapper.Defaults.IGNORE_MALFORMED, + true, + true, + Version.CURRENT + ) ); - assertEquals(VectorDataType.BINARY, mappers.fieldType().vectorDataType); } } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 0b918bd9ed..25982fb7d6 100644 --- 
a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -185,9 +185,8 @@ public void testDoToQuery_Normal() throws Exception { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); assertEquals(knnQueryBuilder.getK(), query.getK()); @@ -207,7 +206,6 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( @@ -215,7 +213,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + 
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); float resultSimilarity = KNNEngine.LUCENE.distanceToRadialThreshold(MAX_DISTANCE, SpaceType.L2); @@ -239,7 +237,6 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( @@ -247,7 +244,7 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); assertTrue(query.toString().contains("resultSimilarity=" + 0.5f)); } @@ -266,16 +263,14 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSuppor QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); 
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -298,16 +293,14 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupp QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); IndexSettings indexSettings = mock(IndexSettings.class); 
when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -325,16 +318,14 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSuppor QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -352,16 +343,14 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupp QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new 
MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -383,16 +372,14 @@ public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSu QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -416,16 +403,14 @@ public void 
testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_then QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) - ); + KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); IndexSettings indexSettings = mock(IndexSettings.class); when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); when(indexSettings.getMaxResultWindow()).thenReturn(1000); @@ -444,7 +429,6 @@ public void testDoToQuery_whenRadialSearchOnBinaryIndex_thenException() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(8); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); MethodComponentContext methodComponentContext = new MethodComponentContext( @@ -452,7 +436,7 @@ public void testDoToQuery_whenRadialSearchOnBinaryIndex_thenException() { ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, 
SpaceType.HAMMING, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 8)); Exception e = expectThrows(UnsupportedOperationException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); assertTrue(e.getMessage().contains("Binary data type does not support radial search")); } @@ -470,15 +454,13 @@ public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); // When @@ -504,14 +486,13 @@ public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_th QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); MethodComponentContext methodComponentContext = 
new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); assertNotNull(query); @@ -531,14 +512,13 @@ public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenS QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); assertNotNull(query); @@ -553,15 +533,13 @@ public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - 
when(mockKNNVectorField.getDimension()).thenReturn(4); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); // When @@ -586,10 +564,12 @@ public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParamete QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - when(mockKNNVectorField.getDimension()).thenReturn(4); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn( - new KNNMethodContext(KNNEngine.LUCENE, SpaceType.COSINESIMIL, new MethodComponentContext("hnsw", Map.of())) + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.LUCENE, + SpaceType.COSINESIMIL, + new MethodComponentContext("hnsw", Map.of()) ); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() @@ -609,15 +589,13 @@ public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); 
when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); MethodComponentContext methodComponentContext = new MethodComponentContext( org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of() ); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.L2, methodComponentContext); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } @@ -632,15 +610,12 @@ public void testDoToQuery_FromModel() { when(mockQueryShardContext.index()).thenReturn(dummyIndex); // Dimension is -1. 
In this case, model metadata will need to provide dimension - when(mockKNNVectorField.getDimension()).thenReturn(-K); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); String modelId = "test-model-id"; - when(mockKNNVectorField.getModelId()).thenReturn(modelId); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForModelMapping(modelId, 4)); // Mock the modelDao to return mocked modelMetadata ModelMetadata modelMetadata = mock(ModelMetadata.class); - when(modelMetadata.getDimension()).thenReturn(4); when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); @@ -672,14 +647,11 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(-K); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); String modelId = "test-model-id"; - when(mockKNNVectorField.getModelId()).thenReturn(modelId); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForModelMapping(modelId, 4)); ModelMetadata modelMetadata = mock(ModelMetadata.class); - when(modelMetadata.getDimension()).thenReturn(4); when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); @@ -709,15 +681,11 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_th QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = 
mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - - when(mockKNNVectorField.getDimension()).thenReturn(-K); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(null); String modelId = "test-model-id"; - when(mockKNNVectorField.getModelId()).thenReturn(modelId); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForModelMapping(modelId, 4)); ModelMetadata modelMetadata = mock(ModelMetadata.class); - when(modelMetadata.getDimension()).thenReturn(4); when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); @@ -744,10 +712,10 @@ public void testDoToQuery_InvalidDimensions() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(400); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 400)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); - when(mockKNNVectorField.getDimension()).thenReturn(K); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), K)); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } @@ -769,9 +737,10 @@ public void testDoToQuery_InvalidZeroFloatVector() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); 
when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); + when(knnMethodContext.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); IllegalArgumentException exception = expectThrows( IllegalArgumentException.class, @@ -790,9 +759,10 @@ public void testDoToQuery_InvalidZeroByteVector() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BYTE); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); + when(knnMethodContext.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); IllegalArgumentException exception = expectThrows( IllegalArgumentException.class, @@ -919,9 +889,8 @@ public void testRadialSearch_whenUnsupportedEngine_thenThrowException() { KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); Index dummyIndex = new Index("dummy", "dummy"); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + 
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); @@ -946,9 +915,8 @@ public void testRadialSearch_whenEfSearchIsSet_whenLuceneEngine_thenThrowExcepti KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); Index dummyIndex = new Index("dummy", "dummy"); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); @@ -972,9 +940,8 @@ public void testRadialSearch_whenEfSearchIsSet_whenFaissEngine_thenSuccess() { KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); Index dummyIndex = new Index("dummy", "dummy"); - when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(knnMethodContext, 4)); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); IndexSettings indexSettings = mock(IndexSettings.class); 
when(mockQueryShardContext.getIndexSettings()).thenReturn(indexSettings); @@ -992,9 +959,8 @@ public void testDoToQuery_whenBinary_thenValid() throws Exception { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(32); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.HAMMING); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 32)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); KNNQuery query = (KNNQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); assertArrayEquals(expectedQueryVector, query.getByteQueryVector()); @@ -1008,9 +974,8 @@ public void testDoToQuery_whenBinaryWithInvalidDimension_thenException() throws QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); - when(mockKNNVectorField.getDimension()).thenReturn(8); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BINARY); - when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.HAMMING); + when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 8)); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); Exception ex = expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); assertTrue(ex.getMessage(), ex.getMessage().contains("invalid dimension")); diff --git a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java 
b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java index f9a9704d0b..d1288c5f34 100644 --- a/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/integ/KNNScriptScoringIT.java @@ -46,6 +46,7 @@ import java.util.stream.Collectors; import static org.hamcrest.Matchers.containsString; +import static org.opensearch.knn.KNNTestCase.getMappingConfigForFlatMapping; import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; @@ -744,18 +745,20 @@ private Map createDataset( } private BiFunction getScoreFunction(SpaceType spaceType, float[] queryVector) { - KNNVectorFieldType knnVectorFieldType = new KNNVectorFieldType( - FIELD_NAME, - Collections.emptyMap(), - SpaceType.HAMMING == spaceType ? queryVector.length * 8 : queryVector.length, - SpaceType.HAMMING == spaceType ? VectorDataType.BINARY : VectorDataType.FLOAT, - null - ); List target = new ArrayList<>(queryVector.length); for (float f : queryVector) { target.add(f); } - KNNScoringSpace knnScoringSpace = KNNScoringSpaceFactory.create(spaceType.getValue(), target, knnVectorFieldType); + KNNScoringSpace knnScoringSpace = KNNScoringSpaceFactory.create( + spaceType.getValue(), + target, + new KNNVectorFieldType( + FIELD_NAME, + Collections.emptyMap(), + SpaceType.HAMMING == spaceType ? VectorDataType.BINARY : VectorDataType.FLOAT, + getMappingConfigForFlatMapping(SpaceType.HAMMING == spaceType ? 
queryVector.length * 8 : queryVector.length) + ) + ); switch (spaceType) { case L1: case L2: diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java index 823d210803..c41e9763b5 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceFactoryTests.java @@ -19,9 +19,11 @@ public class KNNScoringSpaceFactoryTests extends KNNTestCase { public void testValidSpaces() { KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class); - when(knnVectorFieldType.getDimension()).thenReturn(3); + when(knnVectorFieldType.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 3)); KNNVectorFieldType knnVectorFieldTypeBinary = mock(KNNVectorFieldType.class); - when(knnVectorFieldTypeBinary.getDimension()).thenReturn(24); + when(knnVectorFieldTypeBinary.getKnnMappingConfig()).thenReturn( + getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 24) + ); when(knnVectorFieldTypeBinary.getVectorDataType()).thenReturn(VectorDataType.BINARY); NumberFieldMapper.NumberFieldType numberFieldType = new NumberFieldMapper.NumberFieldType( "field", @@ -66,9 +68,11 @@ public void testValidSpaces() { public void testInvalidSpace() { List floatQueryObject = List.of(1.0f, 1.0f, 1.0f); KNNVectorFieldType knnVectorFieldType = mock(KNNVectorFieldType.class); - when(knnVectorFieldType.getDimension()).thenReturn(3); + when(knnVectorFieldType.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 3)); KNNVectorFieldType knnVectorFieldTypeBinary = mock(KNNVectorFieldType.class); - when(knnVectorFieldTypeBinary.getDimension()).thenReturn(24); + when(knnVectorFieldTypeBinary.getKnnMappingConfig()).thenReturn( + getMappingConfigForMethodMapping(getDefaultBinaryKNNMethodContext(), 24) + 
); when(knnVectorFieldTypeBinary.getVectorDataType()).thenReturn(VectorDataType.BINARY); // Verify diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java index 6c557c8dd5..4fc549d6bc 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceTests.java @@ -57,8 +57,13 @@ private void expectThrowsExceptionWithKNNFieldWithBinaryDataType(Class clazz) th public void testL2_whenValid_thenSucceed() { float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); - KNNVectorFieldType fieldType = new KNNVectorFieldType("test", Collections.emptyMap(), 3, knnMethodContext); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); + KNNVectorFieldType fieldType = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + getMappingConfigForMethodMapping(knnMethodContext, 3) + ); KNNScoringSpace.L2 l2 = new KNNScoringSpace.L2(arrayListQueryObject, fieldType); assertEquals(1F, l2.getScoringMethod().apply(arrayFloat, arrayFloat), 0.1F); } @@ -73,9 +78,13 @@ public void testCosineSimilarity_whenValid_thenSucceed() { float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject = new ArrayList<>(Arrays.asList(2.0, 4.0, 6.0)); float[] arrayFloat2 = new float[] { 2.0f, 4.0f, 6.0f }; - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); - - KNNVectorFieldType fieldType = new KNNVectorFieldType("test", Collections.emptyMap(), 3, knnMethodContext); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); + KNNVectorFieldType fieldType = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + getMappingConfigForMethodMapping(knnMethodContext, 3) 
+ ); KNNScoringSpace.CosineSimilarity cosineSimilarity = new KNNScoringSpace.CosineSimilarity(arrayListQueryObject, fieldType); assertEquals(2F, cosineSimilarity.getScoringMethod().apply(arrayFloat2, arrayFloat), 0.1F); @@ -92,8 +101,13 @@ public void testCosineSimilarity_whenValid_thenSucceed() { } public void testCosineSimilarity_whenZeroVector_thenException() { - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); - KNNVectorFieldType fieldType = new KNNVectorFieldType("test", Collections.emptyMap(), 3, knnMethodContext); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); + KNNVectorFieldType fieldType = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + getMappingConfigForMethodMapping(knnMethodContext, 3) + ); final List queryZeroVector = List.of(0.0f, 0.0f, 0.0f); IllegalArgumentException exception1 = expectThrows( @@ -116,9 +130,14 @@ public void testInnerProd_whenValid_thenSucceed() { float[] arrayFloat_case1 = new float[] { 1.0f, 2.0f, 3.0f }; List arrayListQueryObject_case1 = new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0)); float[] arrayFloat2_case1 = new float[] { 1.0f, 1.0f, 1.0f }; - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); - KNNVectorFieldType fieldType = new KNNVectorFieldType("test", Collections.emptyMap(), 3, knnMethodContext); + KNNVectorFieldType fieldType = new KNNVectorFieldType( + "test", + Collections.emptyMap(), + VectorDataType.FLOAT, + getMappingConfigForMethodMapping(knnMethodContext, 3) + ); KNNScoringSpace.InnerProd innerProd = new KNNScoringSpace.InnerProd(arrayListQueryObject_case1, fieldType); assertEquals(7.0F, innerProd.getScoringMethod().apply(arrayFloat_case1, arrayFloat2_case1), 0.001F); @@ -183,14 +202,14 @@ public void testHammingBit_Base64() { public void testHamming_whenKNNFieldType_thenSucceed() { List arrayListQueryObject = new ArrayList<>(Arrays.asList(1.0, 2.0, 
3.0)); - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); KNNVectorFieldType fieldType = new KNNVectorFieldType( "test", Collections.emptyMap(), - 8 * arrayListQueryObject.size(), - knnMethodContext, - VectorDataType.BINARY + VectorDataType.BINARY, + getMappingConfigForMethodMapping(knnMethodContext, 8 * arrayListQueryObject.size()) ); + KNNScoringSpace.Hamming hamming = new KNNScoringSpace.Hamming(arrayListQueryObject, fieldType); float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f }; diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java index 781ed2350b..2374e4f7bb 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScoringSpaceUtilTests.java @@ -64,7 +64,7 @@ public void testParseKNNVectorQuery() { KNNVectorFieldType fieldType = mock(KNNVectorFieldType.class); - when(fieldType.getDimension()).thenReturn(3); + when(fieldType.getKnnMappingConfig()).thenReturn(getMappingConfigForMethodMapping(getDefaultKNNMethodContext(), 3)); assertArrayEquals(arrayFloat, KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject, 3, VectorDataType.FLOAT), 0.1f); expectThrows( diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java index 3515c690d8..8cff4dfa14 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java @@ -23,7 +23,6 @@ import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.knn.KNNTestCase; -import 
org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; @@ -303,7 +302,7 @@ public void testTrainingIndexSize() { // Setup the request TrainingModelRequest trainingModelRequest = new TrainingModelRequest( null, - KNNMethodContext.getDefault(), + getDefaultKNNMethodContext(), dimension, trainingIndexName, "training-field", @@ -350,7 +349,7 @@ public void testTrainIndexSize_whenDataTypeIsBinary() { // Setup the request TrainingModelRequest trainingModelRequest = new TrainingModelRequest( null, - KNNMethodContext.getDefault(), + getDefaultKNNMethodContext(), dimension, trainingIndexName, "training-field", @@ -398,7 +397,7 @@ public void testTrainIndexSize_whenDataTypeIsByte() { // Setup the request TrainingModelRequest trainingModelRequest = new TrainingModelRequest( null, - KNNMethodContext.getDefault(), + getDefaultKNNMethodContext(), dimension, trainingIndexName, "training-field", diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index 53e59129e8..83d39cfdc7 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -48,7 +48,7 @@ public class TrainingModelRequestTests extends KNNTestCase { public void testStreams() throws IOException { String modelId = "test-model-id"; - KNNMethodContext knnMethodContext = KNNMethodContext.getDefault(); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field"; @@ -105,7 +105,7 @@ public void testStreams() throws IOException { public void testGetters() { String modelId = "test-model-id"; - KNNMethodContext knnMethodContext = 
KNNMethodContext.getDefault(); + KNNMethodContext knnMethodContext = getDefaultKNNMethodContext(); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field"; From 5a5351ff79059f48518a45bde3af89b8e970ff43 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Mon, 12 Aug 2024 09:12:05 -0700 Subject: [PATCH 5/6] Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation (#1945) Signed-off-by: Navneet Verma --- CHANGELOG.md | 3 +- .../org/opensearch/knn/index/KNNSettings.java | 33 ++++- .../index/mapper/FlatVectorFieldMapper.java | 4 + .../index/mapper/KNNVectorFieldMapper.java | 45 +++--- .../mapper/KNNVectorFieldMapperUtil.java | 16 +++ .../knn/index/mapper/LuceneFieldMapper.java | 8 +- .../knn/index/mapper/MethodFieldMapper.java | 20 +++ .../knn/index/mapper/ModelFieldMapper.java | 20 ++- .../knn/index/codec/KNNCodecTestCase.java | 10 +- .../mapper/KNNVectorFieldMapperTests.java | 128 ++++++++++++++++-- .../mapper/KNNVectorFieldMapperUtilTests.java | 22 +++ 11 files changed, 268 insertions(+), 41 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 81c90802ed..f1dc5b14d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874) * Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917) * Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) +* Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945) ### Infrastructure ### Documentation ### Maintenance @@ -32,4 +33,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Refactor 
KNNVectorFieldType from KNNVectorFieldMapper to a separate class for better readability. [#1931](https://github.com/opensearch-project/k-NN/pull/1931) * Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925) * Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877) -* Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939) \ No newline at end of file +* Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939) diff --git a/src/main/java/org/opensearch/knn/index/KNNSettings.java b/src/main/java/org/opensearch/knn/index/KNNSettings.java index 33c7ff410b..4ced38b38e 100644 --- a/src/main/java/org/opensearch/knn/index/KNNSettings.java +++ b/src/main/java/org/opensearch/knn/index/KNNSettings.java @@ -82,6 +82,12 @@ public class KNNSettings { public static final String MODEL_CACHE_SIZE_LIMIT = "knn.model.cache.size.limit"; public static final String ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD = "index.knn.advanced.filtered_exact_search_threshold"; public static final String KNN_FAISS_AVX2_DISABLED = "knn.faiss.avx2.disabled"; + /** + * TODO: This setting is only added to ensure that main branch of k_NN plugin doesn't break till other parts of the + * code is getting ready. Will remove this setting once all changes related to integration of KNNVectorsFormat is added + * for native engines. + */ + public static final String KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED = "knn.use.format.enabled"; /** * Default setting values @@ -255,6 +261,17 @@ public class KNNSettings { NodeScope ); + /** + * TODO: This setting is only added to ensure that main branch of k_NN plugin doesn't break till other parts of the + * code is getting ready. 
Will remove this setting once all changes related to integration of KNNVectorsFormat is added + * for native engines. + */ + public static final Setting KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING = Setting.boolSetting( + KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED, + false, + NodeScope + ); + /** * Dynamic settings */ @@ -379,6 +396,10 @@ private Setting getSetting(String key) { return KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING; } + if (KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED.equals(key)) { + return KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING; + } + throw new IllegalArgumentException("Cannot find setting by key [" + key + "]"); } @@ -397,7 +418,8 @@ public List> getSettings() { MODEL_CACHE_SIZE_LIMIT_SETTING, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING, KNN_FAISS_AVX2_DISABLED_SETTING, - KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING + KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING, + KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING ); return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream())) .collect(Collectors.toList()); @@ -443,6 +465,15 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) { .getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE); } + /** + * TODO: This setting is only added to ensure that main branch of k_NN plugin doesn't break till other parts of the + * code is getting ready. Will remove this setting once all changes related to integration of KNNVectorsFormat is added + * for native engines. 
+ */ + public static boolean getIsLuceneVectorFormatEnabled() { + return KNNSettings.state().getSettingValue(KNNSettings.KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED); + } + public void initialize(Client client, ClusterService clusterService) { this.client = client; this.clusterService = clusterService; diff --git a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java index fffff30f41..146b5132fe 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/FlatVectorFieldMapper.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.mapper; import org.apache.lucene.document.FieldType; +import org.apache.lucene.index.DocValuesType; import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.knn.index.VectorDataType; @@ -57,8 +58,11 @@ private FlatVectorFieldMapper( Version indexCreatedVersion ) { super(simpleName, mappedFieldType, multiFields, copyTo, ignoreMalformed, stored, hasDocValues, indexCreatedVersion, null); + // setting it explicitly false here to ensure that when flatmapper is used Lucene based Vector field is not created. 
+ this.useLuceneBasedVectorField = false; this.perDimensionValidator = selectPerDimensionValidator(vectorDataType); this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); + this.fieldType.setDocValuesType(DocValuesType.BINARY); this.fieldType.freeze(); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 40eaa12aeb..5d4d3ca58c 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -16,7 +16,8 @@ import lombok.extern.log4j.Log4j2; import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; -import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.IndexOptions; import org.opensearch.Version; import org.opensearch.common.Explicit; @@ -456,6 +457,7 @@ public Mapper.Builder parse(String name, Map node, ParserCont protected boolean hasDocValues; protected VectorDataType vectorDataType; protected ModelDao modelDao; + protected boolean useLuceneBasedVectorField; // We need to ensure that the original KNNMethodContext as parsed is stored to initialize the // Builder for serialization. So, we need to store it here. 
This is mainly to ensure that the legacy field mapper @@ -497,16 +499,29 @@ protected void parseCreateField(ParseContext context) throws IOException { parseCreateField(context, fieldType().getKnnMappingConfig().getDimension(), fieldType().getVectorDataType()); } + private Field createVectorField(float[] vectorValue) { + if (useLuceneBasedVectorField) { + return new KnnFloatVectorField(name(), vectorValue, fieldType); + } + return new VectorField(name(), vectorValue, fieldType); + } + + private Field createVectorField(byte[] vectorValue) { + if (useLuceneBasedVectorField) { + return new KnnByteVectorField(name(), vectorValue, fieldType); + } + return new VectorField(name(), vectorValue, fieldType); + } + /** * Function returns a list of fields to be indexed when the vector is float type. * * @param array array of floats - * @param fieldType {@link FieldType} * @return {@link List} of {@link Field} */ - protected List getFieldsForFloatVector(final float[] array, final FieldType fieldType) { + protected List getFieldsForFloatVector(final float[] array) { final List fields = new ArrayList<>(); - fields.add(new VectorField(name(), array, fieldType)); + fields.add(createVectorField(array)); if (this.stored) { fields.add(createStoredFieldForFloatVector(name(), array)); } @@ -517,12 +532,11 @@ protected List getFieldsForFloatVector(final float[] array, final FieldTy * Function returns a list of fields to be indexed when the vector is byte type. 
* * @param array array of bytes - * @param fieldType {@link FieldType} * @return {@link List} of {@link Field} */ - protected List getFieldsForByteVector(final byte[] array, final FieldType fieldType) { + protected List getFieldsForByteVector(final byte[] array) { final List fields = new ArrayList<>(); - fields.add(new VectorField(name(), array, fieldType)); + fields.add(createVectorField(array)); if (this.stored) { fields.add(createStoredFieldForByteVector(name(), array)); } @@ -561,24 +575,14 @@ protected void validatePreparse() { protected void parseCreateField(ParseContext context, int dimension, VectorDataType vectorDataType) throws IOException { validatePreparse(); - if (VectorDataType.BINARY == vectorDataType) { - Optional bytesArrayOptional = getBytesFromContext(context, dimension, vectorDataType); - - if (bytesArrayOptional.isEmpty()) { - return; - } - final byte[] array = bytesArrayOptional.get(); - getVectorValidator().validateVector(array); - context.doc().addAll(getFieldsForByteVector(array, fieldType)); - } else if (VectorDataType.BYTE == vectorDataType) { + if (VectorDataType.BINARY == vectorDataType || VectorDataType.BYTE == vectorDataType) { Optional bytesArrayOptional = getBytesFromContext(context, dimension, vectorDataType); - if (bytesArrayOptional.isEmpty()) { return; } final byte[] array = bytesArrayOptional.get(); getVectorValidator().validateVector(array); - context.doc().addAll(getFieldsForByteVector(array, fieldType)); + context.doc().addAll(getFieldsForByteVector(array)); } else if (VectorDataType.FLOAT == vectorDataType) { Optional floatsArrayOptional = getFloatsFromContext(context, dimension); @@ -587,7 +591,7 @@ protected void parseCreateField(ParseContext context, int dimension, VectorDataT } final float[] array = floatsArrayOptional.get(); getVectorValidator().validateVector(array); - context.doc().addAll(getFieldsForFloatVector(array, fieldType)); + context.doc().addAll(getFieldsForFloatVector(array)); } else { throw new 
IllegalArgumentException( String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD) @@ -714,7 +718,6 @@ public static class Defaults { static { FIELD_TYPE.setTokenized(false); FIELD_TYPE.setIndexOptions(IndexOptions.NONE); - FIELD_TYPE.setDocValuesType(DocValuesType.BINARY); FIELD_TYPE.putAttribute(KNN_FIELD, "true"); // This attribute helps to determine knn field type FIELD_TYPE.freeze(); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java index 0caaf80ab0..9cd6bb4679 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtil.java @@ -238,6 +238,22 @@ static void validateIfKNNPluginEnabled() { } } + /** + * Prerequisite: Index should a knn index which is validated via index settings index.knn setting. This function + * assumes that caller has already validated that index is a KNN index. + * We will use LuceneKNNVectorsFormat when these below condition satisfy: + *
 <ol> + * <li>Index is created with Version of opensearch >= 2.17</li> + * <li>Cluster setting is enabled to use Lucene KNNVectors format. This condition is temporary condition and will be + * removed before release.</li> + * </ol>
+ * @param indexCreatedVersion {@link Version} + * @return true if vector field should use KNNVectorsFormat + */ + static boolean useLuceneKNNVectorsFormat(final Version indexCreatedVersion) { + return indexCreatedVersion.onOrAfter(Version.V_2_17_0) && KNNSettings.getIsLuceneVectorFormatEnabled(); + } + private static SpaceType getSpaceType(final Settings indexSettings, final VectorDataType vectorDataType) { String spaceType = indexSettings.get(KNNSettings.INDEX_KNN_SPACE_TYPE.getKey()); if (spaceType == null) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java index 665c35f6e3..7c3d942b6f 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LuceneFieldMapper.java @@ -17,7 +17,7 @@ import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.KnnByteVectorField; -import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.VectorSimilarityFunction; import org.opensearch.Version; import org.opensearch.common.Explicit; @@ -112,9 +112,9 @@ private LuceneFieldMapper(final KNNVectorFieldType mappedFieldType, final Create } @Override - protected List getFieldsForFloatVector(final float[] array, final FieldType fieldType) { + protected List getFieldsForFloatVector(final float[] array) { final List fieldsToBeAdded = new ArrayList<>(); - fieldsToBeAdded.add(new KnnVectorField(name(), array, fieldType)); + fieldsToBeAdded.add(new KnnFloatVectorField(name(), array, fieldType)); if (hasDocValues && vectorFieldType != null) { fieldsToBeAdded.add(new VectorField(name(), array, vectorFieldType)); @@ -127,7 +127,7 @@ protected List getFieldsForFloatVector(final float[] array, final FieldTy } @Override - protected List getFieldsForByteVector(final byte[] array, final 
FieldType fieldType) { + protected List getFieldsForByteVector(final byte[] array) { final List fieldsToBeAdded = new ArrayList<>(); fieldsToBeAdded.add(new KnnByteVectorField(name(), array, fieldType)); diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index 7a69c941b8..cc2c43386d 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -6,9 +6,12 @@ package org.opensearch.knn.index.mapper; import org.apache.lucene.document.FieldType; +import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.index.VectorEncoding; import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.engine.KNNMethodContext; @@ -99,6 +102,7 @@ private MethodFieldMapper( indexVerision, originalKNNMethodContext ); + this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(indexCreatedVersion); KNNMappingConfig annConfig = mappedFieldType.getKnnMappingConfig(); KNNMethodContext knnMethodContext = annConfig.getKnnMethodContext() .orElseThrow(() -> new IllegalArgumentException("KNN method context cannot be empty")); @@ -118,6 +122,22 @@ private MethodFieldMapper( throw new RuntimeException(String.format("Unable to create KNNVectorFieldMapper: %s", ioe)); } + if (useLuceneBasedVectorField) { + int adjustedDimension = mappedFieldType.vectorDataType == VectorDataType.BINARY + ? annConfig.getDimension() / 8 + : annConfig.getDimension(); + final VectorEncoding encoding = mappedFieldType.vectorDataType == VectorDataType.FLOAT + ? 
VectorEncoding.FLOAT32 + : VectorEncoding.BYTE; + fieldType.setVectorAttributes( + adjustedDimension, + encoding, + SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() + ); + } else { + fieldType.setDocValuesType(DocValuesType.BINARY); + } + this.fieldType.freeze(); initValidatorsAndProcessors(knnMethodContext); knnMethodContext.getSpaceType().validateVectorDataType(vectorDataType); diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index a21a01a5dc..6c7e45e7e5 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -6,9 +6,12 @@ package org.opensearch.knn.index.mapper; import org.apache.lucene.document.FieldType; +import org.apache.lucene.index.DocValuesType; +import org.apache.lucene.index.VectorEncoding; import org.opensearch.Version; import org.opensearch.common.Explicit; import org.opensearch.index.mapper.ParseContext; +import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.indices.ModelDao; @@ -102,7 +105,7 @@ private ModelFieldMapper( this.fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); this.fieldType.putAttribute(MODEL_ID, modelId); - this.fieldType.freeze(); + this.useLuceneBasedVectorField = KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(this.indexCreatedVersion); } @Override @@ -193,6 +196,21 @@ private void initPerDimensionProcessor() { protected void parseCreateField(ParseContext context) throws IOException { validatePreparse(); ModelMetadata modelMetadata = getModelMetadata(modelDao, modelId); + if (useLuceneBasedVectorField) { + int adjustedDimension = modelMetadata.getVectorDataType() == VectorDataType.BINARY + ? 
modelMetadata.getDimension() / 8 + : modelMetadata.getDimension(); + final VectorEncoding encoding = modelMetadata.getVectorDataType() == VectorDataType.FLOAT + ? VectorEncoding.FLOAT32 + : VectorEncoding.BYTE; + fieldType.setVectorAttributes( + adjustedDimension, + encoding, + SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() + ); + } else { + fieldType.setDocValuesType(DocValuesType.BINARY); + } parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getVectorDataType()); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index 00cc2b167c..bf2c33bf9d 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -8,7 +8,9 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.document.KnnVectorField; +import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.Query; @@ -89,8 +91,6 @@ * Test used for testing Codecs */ public class KNNCodecTestCase extends KNNTestCase { - - private static final Codec ACTUAL_CODEC = KNNCodecVersion.current().getDefaultKnnCodecSupplier().get(); private static final FieldType sampleFieldType; static { KNNMethodContext knnMethodContext = new KNNMethodContext( @@ -109,6 +109,7 @@ public class KNNCodecTestCase extends KNNTestCase { } sampleFieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); + sampleFieldType.setDocValuesType(DocValuesType.BINARY); sampleFieldType.putAttribute(KNNVectorFieldMapper.KNN_FIELD, "true"); sampleFieldType.putAttribute(KNNConstants.KNN_ENGINE,
knnMethodContext.getKnnEngine().getName()); sampleFieldType.putAttribute(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue()); @@ -259,6 +260,7 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio iwc.setCodec(codec); FieldType fieldType = new FieldType(KNNVectorFieldMapper.Defaults.FIELD_TYPE); + fieldType.setDocValuesType(DocValuesType.BINARY); fieldType.putAttribute(KNNConstants.MODEL_ID, modelId); fieldType.freeze(); @@ -356,9 +358,9 @@ public void testKnnVectorIndex( /** * Add doc with field "test_vector_one" */ - final FieldType luceneFieldType = KnnVectorField.createFieldType(3, VectorSimilarityFunction.EUCLIDEAN); + final FieldType luceneFieldType = KnnFloatVectorField.createFieldType(3, VectorSimilarityFunction.EUCLIDEAN); float[] array = { 1.0f, 3.0f, 4.0f }; - KnnVectorField vectorField = new KnnVectorField(FIELD_NAME_ONE, array, luceneFieldType); + KnnFloatVectorField vectorField = new KnnFloatVectorField(FIELD_NAME_ONE, array, luceneFieldType); RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); Document doc = new Document(); doc.add(vectorField); diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index f06ff79353..e1d8421125 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -7,11 +7,14 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.util.BytesRef; +import org.mockito.MockedStatic; import org.mockito.Mockito; 
import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.Explicit; @@ -27,14 +30,14 @@ import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.ParseContext; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.engine.KNNMethodContext; +import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; @@ -79,6 +82,7 @@ import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.clipVectorValueToFP16Range; import static org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil.validateFP16VectorValue; +@Log4j2 public class KNNVectorFieldMapperTests extends KNNTestCase { private static final String TEST_FIELD_NAME = "test-field-name"; @@ -739,6 +743,112 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { expectThrows(IllegalArgumentException.class, () -> knnVectorFieldMapper1.merge(knnVectorFieldMapper3)); } + @SneakyThrows + public void testMethodFieldMapperParseCreateField_validInput_thenDifferentFieldTypes() { + MockedStatic utilMockedStatic = Mockito.mockStatic(KNNVectorFieldMapperUtil.class); + for (VectorDataType dataType : VectorDataType.values()) { + log.info("Vector Data Type is : {}", dataType); + int dimension = dataType == VectorDataType.BINARY ? 
TEST_DIMENSION * 8 : TEST_DIMENSION; + final MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, Collections.emptyMap()); + methodComponentContext.setIndexVersion(CURRENT); + SpaceType spaceType = VectorDataType.BINARY == dataType ? SpaceType.DEFAULT_BINARY : SpaceType.INNER_PRODUCT; + final KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, spaceType, methodComponentContext); + + ParseContext.Document document = new ParseContext.Document(); + ContentPath contentPath = new ContentPath(); + ParseContext parseContext = mock(ParseContext.class); + when(parseContext.doc()).thenReturn(document); + when(parseContext.path()).thenReturn(contentPath); + + utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(true); + MethodFieldMapper methodFieldMapper = Mockito.spy( + MethodFieldMapper.createFieldMapper( + TEST_FIELD_NAME, + TEST_FIELD_NAME, + Collections.emptyMap(), + dataType, + dimension, + knnMethodContext, + knnMethodContext, + FieldMapper.MultiFields.empty(), + FieldMapper.CopyTo.empty(), + new Explicit<>(true, true), + false, + false, + CURRENT + ) + ); + + if (dataType == VectorDataType.BINARY) { + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(methodFieldMapper) + .getBytesFromContext(parseContext, TEST_DIMENSION * 8, dataType); + } else if (dataType == VectorDataType.BYTE) { + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(methodFieldMapper).getBytesFromContext(parseContext, TEST_DIMENSION, dataType); + } else { + doReturn(Optional.of(TEST_VECTOR)).when(methodFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION); + } + + methodFieldMapper.parseCreateField(parseContext, dimension, dataType); + + List fields = document.getFields(); + assertEquals(1, fields.size()); + IndexableField field1 = fields.get(0); + if (dataType == VectorDataType.FLOAT) { + assertTrue(field1 instanceof KnnFloatVectorField); + assertEquals(field1.fieldType().vectorEncoding(), 
VectorEncoding.FLOAT32); + } else { + assertTrue(field1 instanceof KnnByteVectorField); + assertEquals(field1.fieldType().vectorEncoding(), VectorEncoding.BYTE); + } + + assertEquals(field1.fieldType().vectorDimension(), TEST_DIMENSION); + assertEquals( + field1.fieldType().vectorSimilarityFunction(), + SpaceType.DEFAULT.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() + ); + + utilMockedStatic.when(() -> KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Mockito.any())).thenReturn(false); + + document = new ParseContext.Document(); + contentPath = new ContentPath(); + when(parseContext.doc()).thenReturn(document); + when(parseContext.path()).thenReturn(contentPath); + methodFieldMapper = Mockito.spy( + MethodFieldMapper.createFieldMapper( + TEST_FIELD_NAME, + TEST_FIELD_NAME, + Collections.emptyMap(), + dataType, + dimension, + knnMethodContext, + knnMethodContext, + FieldMapper.MultiFields.empty(), + FieldMapper.CopyTo.empty(), + new Explicit<>(true, true), + false, + false, + CURRENT + ) + ); + + if (dataType == VectorDataType.FLOAT) { + doReturn(Optional.of(TEST_VECTOR)).when(methodFieldMapper).getFloatsFromContext(parseContext, TEST_DIMENSION); + } else { + doReturn(Optional.of(TEST_BYTE_VECTOR)).when(methodFieldMapper) + .getBytesFromContext(parseContext, dataType == VectorDataType.BINARY ? 
TEST_DIMENSION * 8 : TEST_DIMENSION, dataType); + } + + methodFieldMapper.parseCreateField(parseContext, dimension, dataType); + fields = document.getFields(); + assertEquals(1, fields.size()); + field1 = fields.get(0); + assertTrue(field1 instanceof VectorField); + } + // Close the static mock so that other tests running on this thread are not impacted by + // this mocking + utilMockedStatic.close(); + } + @SneakyThrows public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { // Create a lucene field mapper that creates a binary doc values field as well as KnnVectorField @@ -765,22 +875,22 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { doNothing().when(luceneFieldMapper).validatePreparse(); luceneFieldMapper.parseCreateField(parseContext, TEST_DIMENSION, VectorDataType.FLOAT); - // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnVectorField + // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnFloatVectorField List fields = document.getFields(); assertEquals(2, fields.size()); IndexableField field1 = fields.get(0); IndexableField field2 = fields.get(1); VectorField vectorField; - KnnVectorField knnVectorField; + KnnFloatVectorField knnVectorField; if (field1 instanceof VectorField) { assertTrue(field2 instanceof KnnVectorField); vectorField = (VectorField) field1; - knnVectorField = (KnnVectorField) field2; + knnVectorField = (KnnFloatVectorField) field2; } else { - assertTrue(field1 instanceof KnnVectorField); + assertTrue(field1 instanceof KnnFloatVectorField); assertTrue(field2 instanceof VectorField); - knnVectorField = (KnnVectorField) field1; + knnVectorField = (KnnFloatVectorField) field1; vectorField = (VectorField) field2; } @@ -821,8 +931,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { fields = document.getFields(); assertEquals(1, fields.size()); IndexableField field = 
fields.get(0); - assertTrue(field instanceof KnnVectorField); - knnVectorField = (KnnVectorField) field; + assertTrue(field instanceof KnnFloatVectorField); + knnVectorField = (KnnFloatVectorField) field; assertArrayEquals(TEST_VECTOR, knnVectorField.vectorValue(), 0.001f); } diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java index 8ace5557ef..a80110181a 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperUtilTests.java @@ -13,8 +13,13 @@ import org.apache.lucene.document.StoredField; import org.apache.lucene.util.BytesRef; +import org.junit.Assert; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.Version; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.SpaceType; @@ -105,6 +110,23 @@ public void testValidateVectorDataType_whenFloat_thenValid() { validateValidateVectorDataType(KNNEngine.NMSLIB, KNNConstants.METHOD_HNSW, VectorDataType.FLOAT, null); } + public void testUseLuceneKNNVectorsFormat_withDifferentInputs_thenSuccess() { + final KNNSettings knnSettings = mock(KNNSettings.class); + final MockedStatic mockedStatic = Mockito.mockStatic(KNNSettings.class); + mockedStatic.when(KNNSettings::state).thenReturn(knnSettings); + + mockedStatic.when(KNNSettings::getIsLuceneVectorFormatEnabled).thenReturn(false); + Assert.assertFalse(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_2_16_0)); + Assert.assertFalse(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_3_0_0)); + + 
mockedStatic.when(KNNSettings::getIsLuceneVectorFormatEnabled).thenReturn(true); + Assert.assertTrue(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_2_17_0)); + Assert.assertTrue(KNNVectorFieldMapperUtil.useLuceneKNNVectorsFormat(Version.V_3_0_0)); + // Close the static mock so that other tests running on this thread are not impacted by + // this mocking + mockedStatic.close(); + } + private void validateValidateVectorDataType( final KNNEngine knnEngine, final String methodName, From f5ba77114ef662e91a8ce26838159f383931912c Mon Sep 17 00:00:00 2001 From: Doo Yong Kim <0ctopus13prime@gmail.com> Date: Mon, 12 Aug 2024 12:16:07 -0700 Subject: [PATCH 6/6] Disallow invalid characters for physical file name to be included within vector field name. (#1936) * Block a vector field to have invalid characters for a physical file name. Signed-off-by: Dooyong Kim * Block a vector field to have invalid characters for a physical file name. Signed-off-by: Dooyong Kim --------- Signed-off-by: Dooyong Kim Signed-off-by: Doo Yong Kim <0ctopus13prime@gmail.com> Co-authored-by: Dooyong Kim --- CHANGELOG.md | 1 + .../index/mapper/KNNVectorFieldMapper.java | 31 +++++++++++++++++++ .../mapper/KNNVectorFieldMapperTests.java | 18 +++++++++++ 3 files changed, 50 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f1dc5b14d8..eb8427b1f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917) * Fix graph merge stats size calculation [#1844](https://github.com/opensearch-project/k-NN/pull/1844) * Integrate Lucene Vector field with native engines to use KNNVectorFormat during segment creation [#1945](https://github.com/opensearch-project/k-NN/pull/1945) +* Disallow a vector field to have an invalid character for a physical file name.
[#1936](https://github.com/opensearch-project/k-NN/pull/1936) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 5d4d3ca58c..94756f5956 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -13,6 +13,8 @@ import java.util.Map; import java.util.Optional; import java.util.function.Supplier; +import java.util.stream.Collectors; + import lombok.extern.log4j.Log4j2; import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; @@ -23,6 +25,7 @@ import org.opensearch.common.Explicit; import org.opensearch.common.ValidationException; import org.opensearch.common.xcontent.support.XContentMapValues; +import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -219,6 +222,8 @@ private void validateFlatMapper() { @Override public KNNVectorFieldMapper build(BuilderContext context) { + validateFullFieldName(context); + final MultiFields multiFieldsBuilder = this.multiFieldsBuilder.build(this, context); final CopyTo copyToBuilder = copyTo.build(); final Explicit ignoreMalformed = ignoreMalformed(context); @@ -413,6 +418,32 @@ private KNNEngine validateDimensions(final KNNMethodContext knnMethodContext, fi } return knnEngine; } + + /** + * Validates that the provided full field name does not contain any characters that are invalid in a physical file name. + * At the moment, we use the field name as part of a file name, and an exception is thrown + * if a physical file name contains any invalid characters when creating a snapshot. + * To prevent this from happening, we restrict the vector field name so that the generated file has a valid name.
+ * + * @param context the builder context that provides the field name information. + */ + private void validateFullFieldName(final BuilderContext context) { + final String fullFieldName = buildFullName(context); + for (char ch : fullFieldName.toCharArray()) { + if (Strings.INVALID_FILENAME_CHARS.contains(ch)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Vector field name must not include invalid characters of %s. " + + "Provided field name=[%s] had a disallowed character [%c]", + Strings.INVALID_FILENAME_CHARS.stream().map(c -> "'" + c + "'").collect(Collectors.toList()), + fullFieldName, + ch + ) + ); + } + } + } } public static class TypeParser implements Mapper.TypeParser { diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index e1d8421125..b3139fa5c7 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -22,6 +22,7 @@ import org.opensearch.common.settings.IndexScopedSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexSettings; import org.opensearch.index.mapper.ContentPath; @@ -1171,6 +1172,23 @@ public void testBuilder_whenBinaryWithLegacyKNNEnabled_thenException() { assertTrue(ex.getMessage(), ex.getMessage().contains("is not supported for")); } + public void testBuild_whenInvalidCharsInFieldName_thenThrowException() { + for (char disallowChar : Strings.INVALID_FILENAME_CHARS) { + // When an invalid vector field name was given. + final String invalidVectorFieldName = "fieldname" + disallowChar; + + // Prepare context. 
+ Settings settings = Settings.builder().put(settings(CURRENT).build()).put(KNN_INDEX, true).build(); + Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); + + // IllegalArgumentException should be thrown. + Exception e = assertThrows(IllegalArgumentException.class, () -> { + new KNNVectorFieldMapper.Builder(invalidVectorFieldName, null, CURRENT, null).build(builderContext); + }); + assertTrue(e.getMessage(), e.getMessage().contains("Vector field name must not include")); + } + } + private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder createLuceneFieldMapperInputBuilder( VectorDataType vectorDataType ) {