Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mode/comp params so parameter overrides work #2083

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import org.opensearch.common.ValidationException;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;

import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;

/**
 * Abstract {@link MethodResolver} with helpful utility functions that can be shared across different
 * implementations
 */
public abstract class AbstractMethodResolver implements MethodResolver {

    /**
     * Derive the compression level implied by the encoder found in a resolved method context.
     *
     * @param resolvedKnnMethodContext Resolved method context. Should have an encoder set in the params if available
     * @param knnMethodConfigContext configuration context handed to the encoder's compression calculation
     * @param encoderMap encoders keyed by encoder name; used to look up the encoder named in the context
     * @return {@link CompressionLevel#x1} when no encoder is specified in the context,
     *         {@link CompressionLevel#NOT_CONFIGURED} when the named encoder is absent from {@code encoderMap},
     *         otherwise the value returned by {@link Encoder#calculateCompressionLevel}
     */
    protected CompressionLevel resolveCompressionLevelFromMethodContext(
        KNNMethodContext resolvedKnnMethodContext,
        KNNMethodConfigContext knnMethodConfigContext,
        Map<String, Encoder> encoderMap
    ) {
        if (isEncoderSpecified(resolvedKnnMethodContext)) {
            final Encoder matchedEncoder = encoderMap.get(getEncoderName(resolvedKnnMethodContext));
            // An encoder name that is not in the map cannot be evaluated, so the level is reported as not configured
            return matchedEncoder == null
                ? CompressionLevel.NOT_CONFIGURED
                : matchedEncoder.calculateCompressionLevel(getEncoderComponentContext(resolvedKnnMethodContext), knnMethodConfigContext);
        }
        // A null context or a context without an encoder parameter is treated as x1
        return CompressionLevel.x1;
    }

    /**
     * Write the parameter map produced by {@link MethodComponent#getParameterMapWithDefaultsAdded} back into the
     * parameters of the given component context. The context's parameter map is modified in place.
     *
     * @param methodComponentContext component context whose parameters are updated
     * @param knnMethodConfigContext configuration context passed to the defaults computation
     * @param methodComponent component passed to the defaults computation
     */
    protected void resolveMethodParams(
        MethodComponentContext methodComponentContext,
        KNNMethodConfigContext knnMethodConfigContext,
        MethodComponent methodComponent
    ) {
        final Map<String, Object> paramsWithDefaults = MethodComponent.getParameterMapWithDefaultsAdded(
            methodComponentContext,
            methodComponent,
            knnMethodConfigContext
        );
        methodComponentContext.getParameters().putAll(paramsWithDefaults);
    }

    /**
     * Create the method context that resolution will operate on.
     *
     * @param originalMethodContext method context to copy; may be null
     * @param knnEngine engine used only when {@code originalMethodContext} is null
     * @param spaceType space type used only when {@code originalMethodContext} is null
     * @param methodName method name used only when {@code originalMethodContext} is null
     * @return a new {@link KNNMethodContext} built from the original via its copy constructor, or, when the original
     *         is null, a new context for the given engine, space type and method name with an empty parameter map
     */
    protected KNNMethodContext initResolvedKNNMethodContext(
        KNNMethodContext originalMethodContext,
        KNNEngine knnEngine,
        SpaceType spaceType,
        String methodName
    ) {
        if (originalMethodContext != null) {
            return new KNNMethodContext(originalMethodContext);
        }
        final MethodComponentContext emptyComponent = new MethodComponentContext(methodName, new HashMap<>());
        return new KNNMethodContext(knnEngine, spaceType, emptyComponent);
    }

    /**
     * Get the name of the encoder stored in the method context.
     *
     * @param knnMethodContext {@link KNNMethodContext}; may be null
     * @return the encoder component's name, or null when no encoder is specified or the encoder entry is null
     */
    protected String getEncoderName(KNNMethodContext knnMethodContext) {
        if (isEncoderSpecified(knnMethodContext)) {
            final MethodComponentContext encoderContext = getEncoderComponentContext(knnMethodContext);
            return encoderContext == null ? null : encoderContext.getName();
        }
        return null;
    }

    /**
     * Get the encoder's component context from the method context parameters.
     *
     * @param knnMethodContext {@link KNNMethodContext}; may be null
     * @return the value stored under {@code METHOD_ENCODER_PARAMETER} cast to {@link MethodComponentContext}, or null
     *         when no encoder is specified
     */
    protected MethodComponentContext getEncoderComponentContext(KNNMethodContext knnMethodContext) {
        if (isEncoderSpecified(knnMethodContext) == false) {
            return null;
        }

        final Object encoderEntry = knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_ENCODER_PARAMETER);
        return (MethodComponentContext) encoderEntry;
    }

    /**
     * Determine if the encoder parameter is specified
     *
     * @param knnMethodContext {@link KNNMethodContext}; may be null
     * @return true if the context is non-null, its parameters are non-null and they contain the encoder key;
     *         false otherwise
     */
    protected boolean isEncoderSpecified(KNNMethodContext knnMethodContext) {
        if (knnMethodContext == null) {
            return false;
        }
        if (knnMethodContext.getMethodComponentContext().getParameters() == null) {
            return false;
        }
        return knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_ENCODER_PARAMETER);
    }

    /**
     * Decide whether an encoder needs to be resolved for the given configuration.
     *
     * The encoder is not resolved if any of the following holds:
     * 1. The encoder is already specified
     * 2. The compression is x1
     * 3. The compression is not configured and the mode is not on-disk
     * 4. The vector data type is not float
     *
     * @param knnMethodContext {@link KNNMethodContext}; may be null
     * @param knnMethodConfigContext configuration supplying compression level, mode and vector data type
     * @return true when none of the conditions above applies; false otherwise
     */
    protected boolean shouldEncoderBeResolved(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
        if (isEncoderSpecified(knnMethodContext)) {
            return false;
        }

        final CompressionLevel requestedCompression = knnMethodConfigContext.getCompressionLevel();
        if (requestedCompression == CompressionLevel.x1) {
            return false;
        }

        final boolean compressionUnset = CompressionLevel.isConfigured(requestedCompression) == false;
        if (compressionUnset && knnMethodConfigContext.getMode() != Mode.ON_DISK) {
            return false;
        }

        return knnMethodConfigContext.getVectorDataType() == VectorDataType.FLOAT;
    }

    /**
     * Record a validation error when the engine is used from a training context.
     *
     * @param shouldRequireTraining whether the configuration requires training
     * @param knnEngine engine whose name is reported in the error
     * @param validationException exception to append to; may be null
     * @return the exception with the error appended (created if it was null) when training is required; otherwise
     *         the exception passed in, unchanged
     */
    protected ValidationException validateNotTrainingContext(
        boolean shouldRequireTraining,
        KNNEngine knnEngine,
        ValidationException validationException
    ) {
        if (shouldRequireTraining == false) {
            return validationException;
        }

        final ValidationException errors = orNew(validationException);
        errors.addValidationError(String.format(Locale.ROOT, "Cannot use \"%s\" engine from training context", knnEngine.getName()));
        return errors;
    }

    /**
     * Record a validation error when a configured compression level is not in the supported set.
     *
     * @param compressionLevel compression level to check
     * @param supportedCompressionLevels levels the engine supports
     * @param knnEngine engine whose name is reported in the error
     * @param validationException exception to append to; may be null
     * @return the exception with the error appended (created if it was null) when the level is configured but
     *         unsupported; otherwise the exception passed in, unchanged
     */
    protected ValidationException validateCompressionSupported(
        CompressionLevel compressionLevel,
        Set<CompressionLevel> supportedCompressionLevels,
        KNNEngine knnEngine,
        ValidationException validationException
    ) {
        if (CompressionLevel.isConfigured(compressionLevel) == false || supportedCompressionLevels.contains(compressionLevel)) {
            return validationException;
        }

        final ValidationException errors = orNew(validationException);
        errors.addValidationError(
            String.format(Locale.ROOT, "\"%s\" does not support \"%s\" compression", knnEngine.getName(), compressionLevel.getName())
        );
        return errors;
    }

    /**
     * Record a validation error when x1 compression is combined with on-disk mode.
     *
     * @param knnMethodConfigContext configuration supplying compression level and mode
     * @param validationException exception to append to; may be null
     * @return the exception with the error appended (created if it was null) when compression is x1 and the mode is
     *         on-disk; otherwise the exception passed in, unchanged
     */
    protected ValidationException validateCompressionNotx1WhenOnDisk(
        KNNMethodConfigContext knnMethodConfigContext,
        ValidationException validationException
    ) {
        if (knnMethodConfigContext.getCompressionLevel() != CompressionLevel.x1 || knnMethodConfigContext.getMode() != Mode.ON_DISK) {
            return validationException;
        }

        final ValidationException errors = orNew(validationException);
        errors.addValidationError(
            String.format(Locale.ROOT, "Cannot specify \"x1\" compression level when using \"%s\" mode", Mode.ON_DISK.getName())
        );
        return errors;
    }

    /**
     * Throw when both compression levels are configured and they differ.
     *
     * @param originalCompressionLevel compression level before resolution
     * @param resolvedCompressionLevel compression level after resolution
     * @throws ValidationException if both levels are configured and are not the same
     */
    protected void validateCompressionConflicts(CompressionLevel originalCompressionLevel, CompressionLevel resolvedCompressionLevel) {
        final boolean bothConfigured = CompressionLevel.isConfigured(originalCompressionLevel)
            && CompressionLevel.isConfigured(resolvedCompressionLevel);
        if (bothConfigured == false || resolvedCompressionLevel == originalCompressionLevel) {
            return;
        }

        final ValidationException conflict = new ValidationException();
        conflict.addValidationError("Cannot specify an encoder that conflicts with the provided compression level");
        throw conflict;
    }

    /**
     * Return the given exception, or a fresh one when it is null, so errors can be accumulated lazily.
     */
    private static ValidationException orNew(ValidationException validationException) {
        return validationException == null ? new ValidationException() : validationException;
    }
}
12 changes: 12 additions & 0 deletions src/main/java/org/opensearch/knn/index/engine/Encoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.knn.index.engine;

import org.opensearch.knn.index.mapper.CompressionLevel;

/**
* Interface representing an encoder. An encoder generally refers to a vector quantizer.
*/
Expand All @@ -24,4 +26,14 @@ default String getName() {
* @return Method component associated with the encoder
*/
MethodComponent getMethodComponent();

/**
 * Calculate the compression level for the given params. Assume float32 vectors are used. All parameters should
 * be resolved in the encoderContext passed in.
 *
 * @param encoderContext Context for the encoder to extract params from
 * @param knnMethodConfigContext Method configuration context made available to the calculation
 *                               (NOTE(review): which of its fields are read is implementation-specific - confirm)
 * @return Compression level this encoder produces. If the encoder does not support this calculation yet, it will
 * return {@link CompressionLevel#NOT_CONFIGURED}
 */
CompressionLevel calculateCompressionLevel(MethodComponentContext encoderContext, KNNMethodConfigContext knnMethodConfigContext);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;

/**
* Figures out what {@link KNNEngine} to use based on configuration details
*/
public final class EngineResolver {

public static final EngineResolver INSTANCE = new EngineResolver();

private EngineResolver() {}

/**
* Based on the provided {@link Mode} and {@link CompressionLevel}, resolve to a {@link KNNEngine}.
*
* @param knnMethodConfigContext configuration context
* @param knnMethodContext KNNMethodContext
* @param requiresTraining whether config requires training
* @return {@link KNNEngine}
*/
public KNNEngine resolveEngine(
KNNMethodConfigContext knnMethodConfigContext,
KNNMethodContext knnMethodContext,
boolean requiresTraining
) {
// User configuration gets precedence
if (knnMethodContext != null && knnMethodContext.isEngineConfigured()) {
return knnMethodContext.getKnnEngine();
}

// Faiss is the only engine that supports training, so we default to faiss here for now
if (requiresTraining) {
return KNNEngine.FAISS;
}

Mode mode = knnMethodConfigContext.getMode();
CompressionLevel compressionLevel = knnMethodConfigContext.getCompressionLevel();
// If both mode and compression are not specified, we can just default
if (Mode.isConfigured(mode) == false && CompressionLevel.isConfigured(compressionLevel) == false) {
jmazanec15 marked this conversation as resolved.
Show resolved Hide resolved
return KNNEngine.DEFAULT;
}

// For 1x, we need to default to faiss if mode is provided and use nmslib otherwise
if (CompressionLevel.isConfigured(compressionLevel) == false || compressionLevel == CompressionLevel.x1) {
return mode == Mode.ON_DISK ? KNNEngine.FAISS : KNNEngine.DEFAULT;
}

// Lucene is only engine that supports 4x - so we have to default to it here.
if (compressionLevel == CompressionLevel.x4) {
return KNNEngine.LUCENE;
}

return KNNEngine.FAISS;
}
}
10 changes: 10 additions & 0 deletions src/main/java/org/opensearch/knn/index/engine/KNNEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,4 +201,14 @@ public void setInitialized(Boolean isInitialized) {
public List<String> mmapFileExtensions() {
return knnLibrary.mmapFileExtensions();
}

@Override
public ResolvedMethodContext resolveMethod(
jmazanec15 marked this conversation as resolved.
Show resolved Hide resolved
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext,
boolean shouldRequireTraining,
final SpaceType spaceType
) {
return knnLibrary.resolveMethod(knnMethodContext, knnMethodConfigContext, shouldRequireTraining, spaceType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
/**
* KNNLibrary is an interface that helps the plugin communicate with k-NN libraries
*/
public interface KNNLibrary {
public interface KNNLibrary extends MethodResolver {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need to implement this? I would prefer punting on it for now and rethinking it later, because no other methods in this interface come in through other interfaces.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what you mean. Typically, the requests are routed to the correct library via the knn engine. So, by implementing on the engine and library, we can efficiently route it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's odd when KNNLibrary doesn't follow the pattern used for its other functions, which are standalone but could be broken down into interfaces.


/**
* Gets the version of the library that is being used. In general, this can be used for ensuring compatibility of
Expand Down
Loading
Loading