diff --git a/CHANGELOG.md b/CHANGELOG.md index 867fe9683..53bab7670 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,3 +26,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824) * Refactor engine package structure [#1913](https://github.com/opensearch-project/k-NN/pull/1913) * Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920) +* Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925) \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNLibrary.java b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNLibrary.java index 9d83f42a8..92e34be7c 100644 --- a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNLibrary.java @@ -24,10 +24,18 @@ public abstract class AbstractKNNLibrary implements KNNLibrary { protected final String version; @Override - public EngineSpecificMethodContext getMethodContext(String methodName) { + public KNNLibrarySearchContext getKNNLibrarySearchContext(String methodName) { validateMethodExists(methodName); KNNMethod method = methods.get(methodName); - return method.getEngineSpecificMethodContext(); + return method.getKNNLibrarySearchContext(); + } + + @Override + public KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext) { + String method = knnMethodContext.getMethodComponentContext().getName(); + validateMethodExists(method); + KNNMethod knnMethod = methods.get(method); + return knnMethod.getKNNLibraryIndexingContext(knnMethodContext); } @Override @@ -51,14 +59,6 @@ public boolean isTrainingRequired(KNNMethodContext knnMethodContext) { return methods.get(methodName).isTrainingRequired(knnMethodContext); } - @Override - public Map getMethodAsMap(KNNMethodContext knnMethodContext) { - String method = knnMethodContext.getMethodComponentContext().getName(); - validateMethodExists(method); - KNNMethod knnMethod = methods.get(method); - return knnMethod.getAsMap(knnMethodContext); - } - private void validateMethodExists(String methodName) { KNNMethod method = methods.get(methodName); if (method == null) { diff --git a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java index cbed5fe40..6e57e6913 100644 --- a/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java @@ -27,7 +27,7 @@ public abstract class AbstractKNNMethod implements KNNMethod { protected final MethodComponent methodComponent; protected final Set spaces; - protected final EngineSpecificMethodContext engineSpecificMethodContext; + protected final KNNLibrarySearchContext knnLibrarySearchContext; @Override public boolean isSpaceTypeSupported(SpaceType space) { @@ -106,14 +106,14 @@ public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension } @Override - public Map getAsMap(KNNMethodContext knnMethodContext) { + public KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext) { Map parameterMap = new HashMap<>(methodComponent.getAsMap(knnMethodContext.getMethodComponentContext())); parameterMap.put(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue()); - return parameterMap; + return KNNLibraryIndexingContextImpl.builder().parameters(parameterMap).build(); } @Override - public EngineSpecificMethodContext getEngineSpecificMethodContext() { - return engineSpecificMethodContext; + public KNNLibrarySearchContext getKNNLibrarySearchContext() { + return knnLibrarySearchContext; } } diff --git a/src/main/java/org/opensearch/knn/index/engine/DefaultHnswContext.java b/src/main/java/org/opensearch/knn/index/engine/DefaultHnswSearchContext.java similarity index 91% rename from src/main/java/org/opensearch/knn/index/engine/DefaultHnswContext.java rename to src/main/java/org/opensearch/knn/index/engine/DefaultHnswSearchContext.java index ef1d960d4..ecc11f338 100644 --- a/src/main/java/org/opensearch/knn/index/engine/DefaultHnswContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/DefaultHnswSearchContext.java @@ -14,7 +14,7 @@ /** * Default HNSW context for all engines. Have a different implementation if engine context differs. */ -public final class DefaultHnswContext implements EngineSpecificMethodContext { +public final class DefaultHnswSearchContext implements KNNLibrarySearchContext { private final Map> supportedMethodParameters = ImmutableMap.>builder() .put(MethodParameter.EF_SEARCH.getName(), new Parameter.IntegerParameter(MethodParameter.EF_SEARCH.getName(), null, value -> true)) diff --git a/src/main/java/org/opensearch/knn/index/engine/DefaultIVFContext.java b/src/main/java/org/opensearch/knn/index/engine/DefaultIVFSearchContext.java similarity index 90% rename from src/main/java/org/opensearch/knn/index/engine/DefaultIVFContext.java rename to src/main/java/org/opensearch/knn/index/engine/DefaultIVFSearchContext.java index a1a474420..cc612bf8c 100644 --- a/src/main/java/org/opensearch/knn/index/engine/DefaultIVFContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/DefaultIVFSearchContext.java @@ -11,7 +11,7 @@ import java.util.Map; -public final class DefaultIVFContext implements EngineSpecificMethodContext { +public final class DefaultIVFSearchContext implements KNNLibrarySearchContext { private final Map> supportedMethodParameters = ImmutableMap.>builder() .put(MethodParameter.NPROBE.getName(), new Parameter.IntegerParameter(MethodParameter.NPROBE.getName(), null, value -> true)) diff --git a/src/main/java/org/opensearch/knn/index/engine/EngineSpecificMethodContext.java b/src/main/java/org/opensearch/knn/index/engine/EngineSpecificMethodContext.java deleted file mode 100644 index a043bd9cd..000000000 --- a/src/main/java/org/opensearch/knn/index/engine/EngineSpecificMethodContext.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.index.engine; - -import org.opensearch.knn.index.engine.model.QueryContext; - -import java.util.Collections; -import java.util.Map; - -/** - * Holds context related to a method for a particular engine - * Each engine can have a specific set of parameters that it supports during index and build time. This context holds - * the information for each engine method combination. - * - * TODO: Move KnnMethod in here - */ -public interface EngineSpecificMethodContext { - - Map> supportedMethodParameters(QueryContext ctx); - - EngineSpecificMethodContext EMPTY = ctx -> Collections.emptyMap(); -} diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java index f2f2bab35..c7b271783 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNEngine.java @@ -145,11 +145,6 @@ public String getCompoundExtension() { return knnLibrary.getCompoundExtension(); } - @Override - public EngineSpecificMethodContext getMethodContext(String methodName) { - return knnLibrary.getMethodContext(methodName); - } - @Override public float score(float rawScore, SpaceType spaceType) { return knnLibrary.score(rawScore, spaceType); @@ -181,8 +176,13 @@ public boolean isTrainingRequired(KNNMethodContext knnMethodContext) { } @Override - public Map getMethodAsMap(KNNMethodContext knnMethodContext) { - return knnLibrary.getMethodAsMap(knnMethodContext); + public KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext) { + return knnLibrary.getKNNLibraryIndexingContext(knnMethodContext); + } + + @Override + public KNNLibrarySearchContext getKNNLibrarySearchContext(String methodName) { + return knnLibrary.getKNNLibrarySearchContext(methodName); } @Override diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java index c47f39e03..96d492307 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java @@ -11,7 +11,6 @@ import java.util.Collections; import java.util.List; -import java.util.Map; /** * KNNLibrary is an interface that helps the plugin communicate with k-NN libraries @@ -41,13 +40,6 @@ public interface KNNLibrary { */ String getCompoundExtension(); - /** - * Gets metadata related to methods supported by the library - * @param methodName - * @return - */ - EngineSpecificMethodContext getMethodContext(String methodName); - /** * Generate the Lucene score from the rawScore returned by the library. With k-NN, often times the library * will return a score where the lower the score, the better the result. This is the opposite of how Lucene scores @@ -116,12 +108,20 @@ public interface KNNLibrary { int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension); /** - * Generate method as map that can be used to configure the knn index from the jni + * Get the context from the library needed to build the index. * - * @param knnMethodContext to generate parameter map from + * @param knnMethodContext to get build context for * @return parameter map */ - Map getMethodAsMap(KNNMethodContext knnMethodContext); + KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext); + + /** + * Gets metadata related to methods supported by the library + * + * @param methodName name of method + * @return KNNLibrarySearchContext + */ + KNNLibrarySearchContext getKNNLibrarySearchContext(String methodName); /** * Getter for initialized diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java new file mode 100644 index 000000000..d00b7c436 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContext.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine; + +import java.util.Collections; +import java.util.Map; + +/** + * Context a library gives to build one of its indices + */ +public interface KNNLibraryIndexingContext { + /** + * Get map of parameters that get passed to the library to build the index + * + * @return Map of parameters + */ + Map getLibraryParameters(); + + KNNLibraryIndexingContext EMPTY = Collections::emptyMap; +} diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java new file mode 100644 index 000000000..b7c775261 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibraryIndexingContextImpl.java @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine; + +import lombok.Builder; + +import java.util.Map; + +/** + * Simple implementation of {@link KNNLibraryIndexingContext} + */ +@Builder +public class KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext { + + private Map parameters; + + @Override + public Map getLibraryParameters() { + return parameters; + } +} diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNLibrarySearchContext.java b/src/main/java/org/opensearch/knn/index/engine/KNNLibrarySearchContext.java new file mode 100644 index 000000000..b769745f6 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/engine/KNNLibrarySearchContext.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.engine; + +import org.opensearch.knn.index.engine.model.QueryContext; + +import java.util.Collections; +import java.util.Map; + +/** + * Holds the context needed to search a knn library. + */ +public interface KNNLibrarySearchContext { + + /** + * Returns supported parameters for the library. + * + * @param ctx QueryContext + * @return parameters supported by the library + */ + Map> supportedMethodParameters(QueryContext ctx); + + KNNLibrarySearchContext EMPTY = ctx -> Collections.emptyMap(); +} diff --git a/src/main/java/org/opensearch/knn/index/engine/KNNMethod.java b/src/main/java/org/opensearch/knn/index/engine/KNNMethod.java index ea556e8bf..326e5c1e0 100644 --- a/src/main/java/org/opensearch/knn/index/engine/KNNMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/KNNMethod.java @@ -9,8 +9,6 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.training.VectorSpaceInfo; -import java.util.Map; - /** * KNNMethod defines the structure of a method supported by a particular k-NN library. It is used to validate * the KNNMethodContext passed in by the user, where the KNNMethodContext provides the configuration that the user may @@ -61,17 +59,17 @@ public interface KNNMethod { int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension); /** - * Parse knnMethodContext into a map that the library can use to configure the index + * Parse knnMethodContext into context that the library can use to build the index * - * @param knnMethodContext from which to generate map - * @return KNNMethod as a map + * @param knnMethodContext to generate the context for + * @return KNNLibraryIndexingContext */ - Map getAsMap(KNNMethodContext knnMethodContext); + KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext); /** - * Get the method context for a particular method + * Get the search context for a particular method * - * @return EngineSpecificMethodContext for the method + * @return KNNLibrarySearchContext */ - EngineSpecificMethodContext getEngineSpecificMethodContext(); + KNNLibrarySearchContext getKNNLibrarySearchContext(); } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java index ac05afc33..382a71741 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java @@ -9,7 +9,7 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.AbstractKNNMethod; -import org.opensearch.knn.index.engine.DefaultHnswContext; +import org.opensearch.knn.index.engine.DefaultHnswSearchContext; import org.opensearch.knn.index.engine.Encoder; import org.opensearch.knn.index.engine.MethodComponent; import org.opensearch.knn.index.engine.MethodComponentContext; @@ -51,7 +51,7 @@ public class FaissHNSWMethod extends AbstractKNNMethod { * @see AbstractKNNMethod */ public FaissHNSWMethod() { - super(initMethodComponent(), Set.copyOf(SUPPORTED_SPACES), new DefaultHnswContext()); + super(initMethodComponent(), Set.copyOf(SUPPORTED_SPACES), new DefaultHnswSearchContext()); } private static MethodComponent initMethodComponent() { diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java index 813cb6e9e..aa05e8c87 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java @@ -8,7 +8,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.AbstractKNNMethod; -import org.opensearch.knn.index.engine.DefaultIVFContext; +import org.opensearch.knn.index.engine.DefaultIVFSearchContext; import org.opensearch.knn.index.engine.Encoder; import org.opensearch.knn.index.engine.MethodComponent; import org.opensearch.knn.index.engine.MethodComponentContext; @@ -55,7 +55,7 @@ public class FaissIVFMethod extends AbstractKNNMethod { * @see AbstractKNNMethod */ public FaissIVFMethod() { - super(initMethodComponent(), Set.copyOf(SUPPORTED_SPACES), new DefaultIVFContext()); + super(initMethodComponent(), Set.copyOf(SUPPORTED_SPACES), new DefaultIVFSearchContext()); } private static MethodComponent initMethodComponent() { diff --git a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java index 0419a5440..c6fcdb7c4 100644 --- a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWMethod.java @@ -49,7 +49,7 @@ public class LuceneHNSWMethod extends AbstractKNNMethod { * @see AbstractKNNMethod */ public LuceneHNSWMethod() { - super(initMethodComponent(), Set.copyOf(SUPPORTED_SPACES), new LuceneHNSWContext()); + super(initMethodComponent(), Set.copyOf(SUPPORTED_SPACES), new LuceneHNSWSearchContext()); } private static MethodComponent initMethodComponent() { diff --git a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWContext.java b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWSearchContext.java similarity index 88% rename from src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWContext.java rename to src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWSearchContext.java index 808aa66b0..2c4da27df 100644 --- a/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWContext.java +++ b/src/main/java/org/opensearch/knn/index/engine/lucene/LuceneHNSWSearchContext.java @@ -6,7 +6,7 @@ package org.opensearch.knn.index.engine.lucene; import com.google.common.collect.ImmutableMap; -import org.opensearch.knn.index.engine.EngineSpecificMethodContext; +import org.opensearch.knn.index.engine.KNNLibrarySearchContext; import org.opensearch.knn.index.engine.Parameter; import org.opensearch.knn.index.engine.model.QueryContext; import org.opensearch.knn.index.query.request.MethodParameter; @@ -14,7 +14,7 @@ import java.util.Collections; import java.util.Map; -public class LuceneHNSWContext implements EngineSpecificMethodContext { +public class LuceneHNSWSearchContext implements KNNLibrarySearchContext { private final Map> supportedMethodParameters = ImmutableMap.>builder() .put(MethodParameter.EF_SEARCH.getName(), new Parameter.IntegerParameter(MethodParameter.EF_SEARCH.getName(), null, value -> true)) diff --git a/src/main/java/org/opensearch/knn/index/engine/nmslib/NmslibHNSWMethod.java b/src/main/java/org/opensearch/knn/index/engine/nmslib/NmslibHNSWMethod.java index 39c5d5f24..e8e27bcd6 100644 --- a/src/main/java/org/opensearch/knn/index/engine/nmslib/NmslibHNSWMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/nmslib/NmslibHNSWMethod.java @@ -8,7 +8,7 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.AbstractKNNMethod; -import org.opensearch.knn.index.engine.DefaultHnswContext; +import org.opensearch.knn.index.engine.DefaultHnswSearchContext; import org.opensearch.knn.index.engine.MethodComponent; import org.opensearch.knn.index.engine.Parameter; @@ -39,7 +39,7 @@ public class NmslibHNSWMethod extends AbstractKNNMethod { * @see AbstractKNNMethod */ public NmslibHNSWMethod() { - super(initMethodComponent(), Set.copyOf(SUPPORTED_SPACES), new DefaultHnswContext()); + super(initMethodComponent(), Set.copyOf(SUPPORTED_SPACES), new DefaultHnswSearchContext()); } private static MethodComponent initMethodComponent() { diff --git a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java index dba37c927..b15ab1489 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/MethodFieldMapper.java @@ -12,6 +12,7 @@ import org.opensearch.knn.index.engine.KNNEngine; import java.io.IOException; +import java.util.Map; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; @@ -58,10 +59,8 @@ public class MethodFieldMapper extends KNNVectorFieldMapper { this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName()); try { - this.fieldType.putAttribute( - PARAMETERS, - XContentFactory.jsonBuilder().map(knnEngine.getMethodAsMap(knnMethodContext)).toString() - ); + Map libParams = knnEngine.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); + this.fieldType.putAttribute(PARAMETERS, XContentFactory.jsonBuilder().map(libParams).toString()); } catch (IOException ioe) { throw new RuntimeException(String.format("Unable to create KNNVectorFieldMapper: %s", ioe)); } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index f860b282b..b05938c4a 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -32,7 +32,7 @@ import org.opensearch.knn.index.VectorQueryType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser; -import org.opensearch.knn.index.engine.EngineSpecificMethodContext; +import org.opensearch.knn.index.engine.KNNLibrarySearchContext; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; @@ -373,7 +373,7 @@ protected Query doToQuery(QueryShardContext context) { final String method = methodComponentContext != null ? methodComponentContext.getName() : null; if (StringUtils.isNotBlank(method)) { - final EngineSpecificMethodContext engineSpecificMethodContext = knnEngine.getMethodContext(method); + final KNNLibrarySearchContext engineSpecificMethodContext = knnEngine.getKNNLibrarySearchContext(method); QueryContext queryContext = new QueryContext(vectorQueryType); ValidationException validationException = validateParameters( engineSpecificMethodContext.supportedMethodParameters(queryContext), diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index c0cfb72dc..3bdb50ad0 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -182,7 +182,10 @@ public void run() { throw new RuntimeException("Unable to load training data into memory: allocation is already closed"); } setVersionInKnnMethodContext(); - Map trainParameters = model.getModelMetadata().getKnnEngine().getMethodAsMap(knnMethodContext); + Map trainParameters = model.getModelMetadata() + .getKnnEngine() + .getKNNLibraryIndexingContext(knnMethodContext) + .getLibraryParameters(); trainParameters.put( KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) diff --git a/src/test/java/org/opensearch/knn/KNNTestCase.java b/src/test/java/org/opensearch/knn/KNNTestCase.java index 2d7bb6d2c..56c129546 100644 --- a/src/test/java/org/opensearch/knn/KNNTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNTestCase.java @@ -12,7 +12,7 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.knn.index.KNNSettings; -import org.opensearch.knn.index.engine.EngineSpecificMethodContext; +import org.opensearch.knn.index.engine.KNNLibrarySearchContext; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.plugin.stats.KNNCounter; import org.opensearch.core.common.bytes.BytesReference; @@ -32,7 +32,7 @@ */ public class KNNTestCase extends OpenSearchTestCase { - protected static final EngineSpecificMethodContext EMPTY_ENGINE_SPECIFIC_CONTEXT = ctx -> Map.of(); + protected static final KNNLibrarySearchContext EMPTY_ENGINE_SPECIFIC_CONTEXT = ctx -> Map.of(); @Mock protected ClusterService clusterService; diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index ce8fad384..e1ebc5708 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -208,7 +208,9 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException new MethodComponentContext(METHOD_HNSW, ImmutableMap.of(METHOD_PARAMETER_M, 16, METHOD_PARAMETER_EF_CONSTRUCTION, 512)) ); - String parameterString = XContentFactory.jsonBuilder().map(knnEngine.getMethodAsMap(knnMethodContext)).toString(); + String parameterString = XContentFactory.jsonBuilder() + .map(knnEngine.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters()) + .toString(); FieldInfo[] fieldInfoArray = new FieldInfo[] { KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) @@ -328,7 +330,9 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException ); knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); - String parameterString = XContentFactory.jsonBuilder().map(knnEngine.getMethodAsMap(knnMethodContext)).toString(); + String parameterString = XContentFactory.jsonBuilder() + .map(knnEngine.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters()) + .toString(); FieldInfo[] fieldInfoArray = new FieldInfo[] { KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) @@ -393,7 +397,9 @@ public void testAddKNNBinaryField_whenFaissBinary_thenAdded() throws IOException ); knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); - String parameterString = XContentFactory.jsonBuilder().map(knnEngine.getMethodAsMap(knnMethodContext)).toString(); + String parameterString = XContentFactory.jsonBuilder() + .map(knnEngine.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters()) + .toString(); FieldInfo[] fieldInfoArray = new FieldInfo[] { KNNCodecTestUtil.FieldInfoBuilder.builder(fieldName) diff --git a/src/test/java/org/opensearch/knn/index/engine/AbstractKNNLibraryTests.java b/src/test/java/org/opensearch/knn/index/engine/AbstractKNNLibraryTests.java index 5dd185112..c6ab9ccdb 100644 --- a/src/test/java/org/opensearch/knn/index/engine/AbstractKNNLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/AbstractKNNLibraryTests.java @@ -29,7 +29,7 @@ public class AbstractKNNLibraryTests extends KNNTestCase { private final static KNNMethod INVALID_METHOD_THROWS_VALIDATION = new AbstractKNNMethod( MethodComponent.Builder.builder(INVALID_METHOD_THROWS_VALIDATION_NAME).build(), Set.of(SpaceType.DEFAULT), - new DefaultHnswContext() + new DefaultHnswSearchContext() ) { @Override public ValidationException validate(KNNMethodContext knnMethodContext) { @@ -37,7 +37,7 @@ public ValidationException validate(KNNMethodContext knnMethodContext) { } }; private final static String VALID_METHOD_NAME = "test-method-2"; - private final static EngineSpecificMethodContext VALID_METHOD_CONTEXT = ctx -> ImmutableMap.of( + private final static KNNLibrarySearchContext VALID_METHOD_CONTEXT = ctx -> ImmutableMap.of( "myparameter", new Parameter.BooleanParameter("myparameter", null, value -> true) ); @@ -75,15 +75,15 @@ public void testValidateMethod() throws IOException { public void testEngineSpecificMethods() { QueryContext engineSpecificMethodContext = new QueryContext(VectorQueryType.K); - assertNotNull(TEST_LIBRARY.getMethodContext(VALID_METHOD_NAME)); + assertNotNull(TEST_LIBRARY.getKNNLibrarySearchContext(VALID_METHOD_NAME)); assertTrue( - TEST_LIBRARY.getMethodContext(VALID_METHOD_NAME) + TEST_LIBRARY.getKNNLibrarySearchContext(VALID_METHOD_NAME) .supportedMethodParameters(engineSpecificMethodContext) .containsKey("myparameter") ); } - public void testGetMethodAsMap() { + public void testGetKNNLibraryIndexingContext() { // Check that map is expected Map expectedMap = new HashMap<>(VALID_EXPECTED_MAP); expectedMap.put(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()); @@ -92,7 +92,7 @@ public void testGetMethodAsMap() { SpaceType.DEFAULT, new MethodComponentContext(VALID_METHOD_NAME, Collections.emptyMap()) ); - assertEquals(expectedMap, TEST_LIBRARY.getMethodAsMap(knnMethodContext)); + assertEquals(expectedMap, TEST_LIBRARY.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters()); // Check when invalid method is passed in KNNMethodContext invalidKnnMethodContext = new KNNMethodContext( @@ -100,7 +100,7 @@ public void testGetMethodAsMap() { SpaceType.DEFAULT, new MethodComponentContext("invalid", Collections.emptyMap()) ); - expectThrows(IllegalArgumentException.class, () -> TEST_LIBRARY.getMethodAsMap(invalidKnnMethodContext)); + expectThrows(IllegalArgumentException.class, () -> TEST_LIBRARY.getKNNLibraryIndexingContext(invalidKnnMethodContext)); } private static class TestAbstractKNNLibrary extends AbstractKNNLibrary { diff --git a/src/test/java/org/opensearch/knn/index/engine/AbstractKNNMethodTests.java b/src/test/java/org/opensearch/knn/index/engine/AbstractKNNMethodTests.java index 981024e6c..2c739c6f7 100644 --- a/src/test/java/org/opensearch/knn/index/engine/AbstractKNNMethodTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/AbstractKNNMethodTests.java @@ -25,11 +25,7 @@ public class AbstractKNNMethodTests extends KNNTestCase { private static class TestKNNMethod extends AbstractKNNMethod { - public TestKNNMethod( - MethodComponent methodComponent, - Set spaces, - EngineSpecificMethodContext engineSpecificMethodContext - ) { + public TestKNNMethod(MethodComponent methodComponent, Set spaces, KNNLibrarySearchContext engineSpecificMethodContext) { super(methodComponent, spaces, engineSpecificMethodContext); } } @@ -143,7 +139,7 @@ public void testValidateWithData() throws IOException { assertNull(knnMethod.validateWithData(knnMethodContext3, testVectorSpaceInfo)); } - public void testGetAsMap() { + public void testGetKNNLibraryIndexingContext() { SpaceType spaceType = SpaceType.DEFAULT; String methodName = "test-method"; Map generatedMap = ImmutableMap.of("test-key", "test-value"); @@ -158,18 +154,20 @@ public void testGetAsMap() { assertEquals( expectedMap, - knnMethod.getAsMap(new KNNMethodContext(KNNEngine.DEFAULT, spaceType, new MethodComponentContext(methodName, generatedMap))) + knnMethod.getKNNLibraryIndexingContext( + new KNNMethodContext(KNNEngine.DEFAULT, spaceType, new MethodComponentContext(methodName, generatedMap)) + ).getLibraryParameters() ); } - public void testGetEngineSpecificMethodContext() { + public void testGetKNNLibrarySearchContext() { String methodName = "test-method"; - EngineSpecificMethodContext engineSpecificMethodContext = new DefaultHnswContext(); + KNNLibrarySearchContext knnLibrarySearchContext = new DefaultHnswSearchContext(); KNNMethod knnMethod = new TestKNNMethod( MethodComponent.Builder.builder(methodName).build(), Set.of(SpaceType.L2), - engineSpecificMethodContext + knnLibrarySearchContext ); - assertEquals(engineSpecificMethodContext, knnMethod.getEngineSpecificMethodContext()); + assertEquals(knnLibrarySearchContext, knnMethod.getKNNLibrarySearchContext()); } } diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java index 366f4ec77..af5086491 100644 --- a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java @@ -55,7 +55,7 @@ public void testGetMethodAsMap_whenMethodIsHNSWFlat_thenCreateCorrectIndexDescri KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); - Map map = Faiss.INSTANCE.getMethodAsMap(knnMethodContext); + Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); @@ -84,7 +84,7 @@ public void testGetMethodAsMap_whenMethodIsHNSWPQ_thenCreateCorrectIndexDescript KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); - Map map = Faiss.INSTANCE.getMethodAsMap(knnMethodContext); + Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); @@ -113,7 +113,7 @@ public void testGetMethodAsMap_whenMethodIsHNSWSQFP16_thenCreateCorrectIndexDesc KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); - Map map = Faiss.INSTANCE.getMethodAsMap(knnMethodContext); + Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); @@ -134,7 +134,7 @@ public void testGetMethodAsMap_whenMethodIsIVFFlat_thenCreateCorrectIndexDescrip Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); - Map map = Faiss.INSTANCE.getMethodAsMap(knnMethodContext); + Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); @@ -164,7 +164,7 @@ public void testGetMethodAsMap_whenMethodIsIVFPQ_thenCreateCorrectIndexDescripti Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); - Map map = Faiss.INSTANCE.getMethodAsMap(knnMethodContext); + Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); @@ -192,7 +192,7 @@ public void testGetMethodAsMap_whenMethodIsIVFSQFP16_thenCreateCorrectIndexDescr Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); - Map map = Faiss.INSTANCE.getMethodAsMap(knnMethodContext); + Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index 12ae1d444..ae9ad7106 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -611,7 +611,7 @@ public void testTrain_whenConfigurationIsIVFSQFP16_thenSucceed() { .endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); - Map parameters = KNNEngine.FAISS.getMethodAsMap(knnMethodContext); + Map parameters = KNNEngine.FAISS.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); @@ -1131,7 +1131,7 @@ public void testTrain_whenConfigurationIsIVFFlat_thenSucceed() throws IOExceptio .endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); - Map parameters = KNNEngine.FAISS.getMethodAsMap(knnMethodContext); + Map parameters = KNNEngine.FAISS.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); @@ -1162,7 +1162,7 @@ public void testTrain_whenConfigurationIsIVFPQ_thenSucceed() throws IOException .endObject(); Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); - Map parameters = KNNEngine.FAISS.getMethodAsMap(knnMethodContext); + Map parameters = KNNEngine.FAISS.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); @@ -1190,7 +1190,7 @@ public void testTrain_whenConfigurationIsHNSWPQ_thenSucceed() throws IOException Map in = xContentBuilderToMap(xContentBuilder); KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); knnMethodContext.getMethodComponentContext().setIndexVersion(Version.CURRENT); - Map parameters = KNNEngine.FAISS.getMethodAsMap(knnMethodContext); + Map parameters = KNNEngine.FAISS.getKNNLibraryIndexingContext(knnMethodContext).getLibraryParameters(); byte[] faissIndex = JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); @@ -1237,7 +1237,11 @@ public void testCreateIndexFromTemplate() throws IOException { ) ); - String description = knnMethodContext.getKnnEngine().getMethodAsMap(knnMethodContext).get(INDEX_DESCRIPTION_PARAMETER).toString(); + String description = knnMethodContext.getKnnEngine() + .getKNNLibraryIndexingContext(knnMethodContext) + .getLibraryParameters() + .get(INDEX_DESCRIPTION_PARAMETER) + .toString(); assertEquals("IVF16,PQ16x8", description); Map parameters = ImmutableMap.of( @@ -1375,7 +1379,11 @@ private String createFaissIVFPQIndex(int ivfNlist, int pqM, int pqCodeSize, Spac ) ); - String description = knnMethodContext.getKnnEngine().getMethodAsMap(knnMethodContext).get(INDEX_DESCRIPTION_PARAMETER).toString(); + String description = knnMethodContext.getKnnEngine() + .getKNNLibraryIndexingContext(knnMethodContext) + .getLibraryParameters() + .get(INDEX_DESCRIPTION_PARAMETER) + .toString(); Map parameters = ImmutableMap.of( INDEX_DESCRIPTION_PARAMETER, description, diff --git a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java index 8dd4de81e..7fa0d3bca 100644 --- a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java +++ b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java @@ -12,15 +12,14 @@ package org.opensearch.knn.plugin.stats.suppliers; import org.opensearch.common.ValidationException; -import org.opensearch.knn.index.engine.EngineSpecificMethodContext; +import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; +import org.opensearch.knn.index.engine.KNNLibrarySearchContext; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNLibrary; import org.opensearch.knn.training.VectorSpaceInfo; import org.opensearch.test.OpenSearchTestCase; -import java.util.Map; - public class LibraryInitializedSupplierTests extends OpenSearchTestCase { public void testEngineInitialized() { @@ -55,7 +54,7 @@ public String getCompoundExtension() { } @Override - public EngineSpecificMethodContext getMethodContext(String methodName) { + public KNNLibrarySearchContext getKNNLibrarySearchContext(String methodName) { return null; } @@ -95,7 +94,7 @@ public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension } @Override - public Map getMethodAsMap(KNNMethodContext knnMethodContext) { + public KNNLibraryIndexingContext getKNNLibraryIndexingContext(KNNMethodContext knnMethodContext) { return null; }