Skip to content

Commit

Permalink
Generalize lib interface to return context objects
Browse files Browse the repository at this point in the history
Generalizes the KNNLibrary to return an object for both search and
indexing so that the plugin can search/index against them. This will
help properly pass information that does not need to be sent to the JNI
for search and index builds.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Aug 2, 2024
1 parent 523c681 commit 00002b3
Show file tree
Hide file tree
Showing 23 changed files with 178 additions and 117 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824)
* Refactor engine package structure [#1913](https://github.com/opensearch-project/k-NN/pull/1913)
* Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920)
* Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925)
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,18 @@ public abstract class AbstractKNNLibrary implements KNNLibrary {
protected final String version;

@Override
public EngineSpecificMethodContext getMethodContext(String methodName) {
public KNNLibrarySearchContext getKNNLibrarySearchContext(String methodName) {
validateMethodExists(methodName);
KNNMethod method = methods.get(methodName);
return method.getEngineSpecificMethodContext();
return method.getKNNLibrarySearchContext();
}

@Override
public KNNLibraryIndexBuildContext getKNNLibraryIndexBuildContext(KNNMethodContext knnMethodContext) {
String method = knnMethodContext.getMethodComponentContext().getName();
validateMethodExists(method);
KNNMethod knnMethod = methods.get(method);
return knnMethod.getKNNLibraryIndexBuildContext(knnMethodContext);
}

@Override
Expand All @@ -51,14 +59,6 @@ public boolean isTrainingRequired(KNNMethodContext knnMethodContext) {
return methods.get(methodName).isTrainingRequired(knnMethodContext);
}

@Override
public Map<String, Object> getMethodAsMap(KNNMethodContext knnMethodContext) {
String method = knnMethodContext.getMethodComponentContext().getName();
validateMethodExists(method);
KNNMethod knnMethod = methods.get(method);
return knnMethod.getAsMap(knnMethodContext);
}

private void validateMethodExists(String methodName) {
KNNMethod method = methods.get(methodName);
if (method == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public abstract class AbstractKNNMethod implements KNNMethod {

protected final MethodComponent methodComponent;
protected final Set<SpaceType> spaces;
protected final EngineSpecificMethodContext engineSpecificMethodContext;
protected final KNNLibrarySearchContext knnLibrarySearchContext;

@Override
public boolean isSpaceTypeSupported(SpaceType space) {
Expand Down Expand Up @@ -106,14 +106,14 @@ public int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension
}

@Override
public Map<String, Object> getAsMap(KNNMethodContext knnMethodContext) {
public KNNLibraryIndexBuildContext getKNNLibraryIndexBuildContext(KNNMethodContext knnMethodContext) {
Map<String, Object> parameterMap = new HashMap<>(methodComponent.getAsMap(knnMethodContext.getMethodComponentContext()));
parameterMap.put(KNNConstants.SPACE_TYPE, knnMethodContext.getSpaceType().getValue());
return parameterMap;
return KNNLibraryIndexBuildContextImpl.builder().parameters(parameterMap).build();
}

@Override
public EngineSpecificMethodContext getEngineSpecificMethodContext() {
return engineSpecificMethodContext;
public KNNLibrarySearchContext getKNNLibrarySearchContext() {
return knnLibrarySearchContext;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
/**
* Default HNSW context for all engines. Have a different implementation if engine context differs.
*/
public final class DefaultHnswContext implements EngineSpecificMethodContext {
public final class DefaultHnswContext implements KNNLibrarySearchContext {

private final Map<String, Parameter<?>> supportedMethodParameters = ImmutableMap.<String, Parameter<?>>builder()
.put(MethodParameter.EF_SEARCH.getName(), new Parameter.IntegerParameter(MethodParameter.EF_SEARCH.getName(), null, value -> true))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import java.util.Map;

public final class DefaultIVFContext implements EngineSpecificMethodContext {
public final class DefaultIVFContext implements KNNLibrarySearchContext {

private final Map<String, Parameter<?>> supportedMethodParameters = ImmutableMap.<String, Parameter<?>>builder()
.put(MethodParameter.NPROBE.getName(), new Parameter.IntegerParameter(MethodParameter.NPROBE.getName(), null, value -> true))
Expand Down

This file was deleted.

14 changes: 7 additions & 7 deletions src/main/java/org/opensearch/knn/index/engine/KNNEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,6 @@ public String getCompoundExtension() {
return knnLibrary.getCompoundExtension();
}

@Override
public EngineSpecificMethodContext getMethodContext(String methodName) {
return knnLibrary.getMethodContext(methodName);
}

@Override
public float score(float rawScore, SpaceType spaceType) {
return knnLibrary.score(rawScore, spaceType);
Expand Down Expand Up @@ -181,8 +176,13 @@ public boolean isTrainingRequired(KNNMethodContext knnMethodContext) {
}

@Override
public Map<String, Object> getMethodAsMap(KNNMethodContext knnMethodContext) {
return knnLibrary.getMethodAsMap(knnMethodContext);
public KNNLibraryIndexBuildContext getKNNLibraryIndexBuildContext(KNNMethodContext knnMethodContext) {
return knnLibrary.getKNNLibraryIndexBuildContext(knnMethodContext);
}

@Override
public KNNLibrarySearchContext getKNNLibrarySearchContext(String methodName) {
return knnLibrary.getKNNLibrarySearchContext(methodName);
}

@Override
Expand Down
22 changes: 11 additions & 11 deletions src/main/java/org/opensearch/knn/index/engine/KNNLibrary.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import java.util.Collections;
import java.util.List;
import java.util.Map;

/**
* KNNLibrary is an interface that helps the plugin communicate with k-NN libraries
Expand Down Expand Up @@ -41,13 +40,6 @@ public interface KNNLibrary {
*/
String getCompoundExtension();

/**
* Gets metadata related to methods supported by the library
* @param methodName
* @return
*/
EngineSpecificMethodContext getMethodContext(String methodName);

/**
* Generate the Lucene score from the rawScore returned by the library. With k-NN, often times the library
* will return a score where the lower the score, the better the result. This is the opposite of how Lucene scores
Expand Down Expand Up @@ -116,12 +108,20 @@ public interface KNNLibrary {
int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension);

/**
* Generate method as map that can be used to configure the knn index from the jni
* Get the context from the library needed to build the index.
*
* @param knnMethodContext to generate parameter map from
* @param knnMethodContext to get build context for
* @return parameter map
*/
Map<String, Object> getMethodAsMap(KNNMethodContext knnMethodContext);
KNNLibraryIndexBuildContext getKNNLibraryIndexBuildContext(KNNMethodContext knnMethodContext);

/**
* Gets metadata related to methods supported by the library
*
* @param methodName name of method
* @return KNNLibrarySearchContext
*/
KNNLibrarySearchContext getKNNLibrarySearchContext(String methodName);

/**
* Getter for initialized
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import java.util.Collections;
import java.util.Map;

/**
* Context a library gives to build one of its indices
*/
public interface KNNLibraryIndexBuildContext {
/**
* Get map of parameters that get passed to the library to build the index
*
* @return Map of parameters
*/
Map<String, Object> getLibraryParameters();

KNNLibraryIndexBuildContext EMPTY = Collections::emptyMap;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import lombok.Builder;

import java.util.Map;

/**
* Simple implementation of {@link KNNLibraryIndexBuildContext}
*/
@Builder
public class KNNLibraryIndexBuildContextImpl implements KNNLibraryIndexBuildContext {

private Map<String, Object> parameters;

@Override
public Map<String, Object> getLibraryParameters() {
return parameters;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import org.opensearch.knn.index.engine.model.QueryContext;

import java.util.Collections;
import java.util.Map;

/**
* Holds the context needed to search a knn library.
*/
public interface KNNLibrarySearchContext {

/**
* Returns supported parameters for the library.
*
* @param ctx QueryContext
* @return parameters supported by the library
*/
Map<String, Parameter<?>> supportedMethodParameters(QueryContext ctx);

KNNLibrarySearchContext EMPTY = ctx -> Collections.emptyMap();
}
16 changes: 7 additions & 9 deletions src/main/java/org/opensearch/knn/index/engine/KNNMethod.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.training.VectorSpaceInfo;

import java.util.Map;

/**
* KNNMethod defines the structure of a method supported by a particular k-NN library. It is used to validate
* the KNNMethodContext passed in by the user, where the KNNMethodContext provides the configuration that the user may
Expand Down Expand Up @@ -61,17 +59,17 @@ public interface KNNMethod {
int estimateOverheadInKB(KNNMethodContext knnMethodContext, int dimension);

/**
* Parse knnMethodContext into a map that the library can use to configure the index
* Parse knnMethodContext into context that the library can use to build the index
*
* @param knnMethodContext from which to generate map
* @return KNNMethod as a map
* @param knnMethodContext to generate the context for
* @return KNNLibraryIndexBuildContext
*/
Map<String, Object> getAsMap(KNNMethodContext knnMethodContext);
KNNLibraryIndexBuildContext getKNNLibraryIndexBuildContext(KNNMethodContext knnMethodContext);

/**
* Get the method context for a particular method
* Get the search context for a particular method
*
* @return EngineSpecificMethodContext for the method
* @return KNNLibrarySearchContext
*/
EngineSpecificMethodContext getEngineSpecificMethodContext();
KNNLibrarySearchContext getKNNLibrarySearchContext();
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
package org.opensearch.knn.index.engine.lucene;

import com.google.common.collect.ImmutableMap;
import org.opensearch.knn.index.engine.EngineSpecificMethodContext;
import org.opensearch.knn.index.engine.KNNLibrarySearchContext;
import org.opensearch.knn.index.engine.Parameter;
import org.opensearch.knn.index.engine.model.QueryContext;
import org.opensearch.knn.index.query.request.MethodParameter;

import java.util.Collections;
import java.util.Map;

public class LuceneHNSWContext implements EngineSpecificMethodContext {
public class LuceneHNSWContext implements KNNLibrarySearchContext {

private final Map<String, Parameter<?>> supportedMethodParameters = ImmutableMap.<String, Parameter<?>>builder()
.put(MethodParameter.EF_SEARCH.getName(), new Parameter.IntegerParameter(MethodParameter.EF_SEARCH.getName(), null, value -> true))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.knn.index.engine.KNNEngine;

import java.io.IOException;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
Expand Down Expand Up @@ -58,10 +59,8 @@ public class MethodFieldMapper extends KNNVectorFieldMapper {
this.fieldType.putAttribute(KNN_ENGINE, knnEngine.getName());

try {
this.fieldType.putAttribute(
PARAMETERS,
XContentFactory.jsonBuilder().map(knnEngine.getMethodAsMap(knnMethodContext)).toString()
);
Map<String, Object> libParams = knnEngine.getKNNLibraryIndexBuildContext(knnMethodContext).getLibraryParameters();
this.fieldType.putAttribute(PARAMETERS, XContentFactory.jsonBuilder().map(libParams).toString());
} catch (IOException ioe) {
throw new RuntimeException(String.format("Unable to create KNNVectorFieldMapper: %s", ioe));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import org.opensearch.knn.index.VectorQueryType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser;
import org.opensearch.knn.index.engine.EngineSpecificMethodContext;
import org.opensearch.knn.index.engine.KNNLibrarySearchContext;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
Expand Down Expand Up @@ -373,7 +373,7 @@ protected Query doToQuery(QueryShardContext context) {

final String method = methodComponentContext != null ? methodComponentContext.getName() : null;
if (StringUtils.isNotBlank(method)) {
final EngineSpecificMethodContext engineSpecificMethodContext = knnEngine.getMethodContext(method);
final KNNLibrarySearchContext engineSpecificMethodContext = knnEngine.getKNNLibrarySearchContext(method);
QueryContext queryContext = new QueryContext(vectorQueryType);
ValidationException validationException = validateParameters(
engineSpecificMethodContext.supportedMethodParameters(queryContext),
Expand Down
5 changes: 4 additions & 1 deletion src/main/java/org/opensearch/knn/training/TrainingJob.java
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,10 @@ public void run() {
throw new RuntimeException("Unable to load training data into memory: allocation is already closed");
}
setVersionInKnnMethodContext();
Map<String, Object> trainParameters = model.getModelMetadata().getKnnEngine().getMethodAsMap(knnMethodContext);
Map<String, Object> trainParameters = model.getModelMetadata()
.getKnnEngine()
.getKNNLibraryIndexBuildContext(knnMethodContext)
.getLibraryParameters();
trainParameters.put(
KNNConstants.INDEX_THREAD_QTY,
KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)
Expand Down
4 changes: 2 additions & 2 deletions src/test/java/org/opensearch/knn/KNNTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.engine.EngineSpecificMethodContext;
import org.opensearch.knn.index.engine.KNNLibrarySearchContext;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.core.common.bytes.BytesReference;
Expand All @@ -32,7 +32,7 @@
*/
public class KNNTestCase extends OpenSearchTestCase {

protected static final EngineSpecificMethodContext EMPTY_ENGINE_SPECIFIC_CONTEXT = ctx -> Map.of();
protected static final KNNLibrarySearchContext EMPTY_ENGINE_SPECIFIC_CONTEXT = ctx -> Map.of();

@Mock
protected ClusterService clusterService;
Expand Down
Loading

0 comments on commit 00002b3

Please sign in to comment.