-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix mode/comp params so parameter overrides work (#2083)
PR adds capability to override parameters when specifying mode and compression. In order to do this, I add functionality for creating a deep copy of KNNMethodContext and MethodComponentContext so that we wouldnt overwrite user provided config. Then, re-arranged some of the parameter resolution logic. Signed-off-by: John Mazanec <[email protected]>
- Loading branch information
1 parent
524dbd0
commit 270ac6a
Showing
56 changed files
with
2,715 additions
and
479 deletions.
There are no files selected for viewing
185 changes: 185 additions & 0 deletions
185
src/main/java/org/opensearch/knn/index/engine/AbstractMethodResolver.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.engine; | ||
|
||
import org.opensearch.common.ValidationException; | ||
import org.opensearch.knn.index.SpaceType; | ||
import org.opensearch.knn.index.VectorDataType; | ||
import org.opensearch.knn.index.mapper.CompressionLevel; | ||
import org.opensearch.knn.index.mapper.Mode; | ||
|
||
import java.util.HashMap; | ||
import java.util.Locale; | ||
import java.util.Map; | ||
import java.util.Set; | ||
|
||
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; | ||
|
||
/** | ||
* Abstract {@link MethodResolver} with helpful utilitiy functions that can be shared across different | ||
* implementations | ||
*/ | ||
public abstract class AbstractMethodResolver implements MethodResolver { | ||
|
||
/** | ||
* Utility method to get the compression level from the context | ||
* | ||
* @param resolvedKnnMethodContext Resolved method context. Should have an encoder set in the params if available | ||
* @return {@link CompressionLevel} Compression level that is configured with the {@link KNNMethodContext} | ||
*/ | ||
protected CompressionLevel resolveCompressionLevelFromMethodContext( | ||
KNNMethodContext resolvedKnnMethodContext, | ||
KNNMethodConfigContext knnMethodConfigContext, | ||
Map<String, Encoder> encoderMap | ||
) { | ||
// If the context is null, the compression is not configured or the encoder is not defined, return not configured | ||
// because the method context does not contain this info | ||
if (isEncoderSpecified(resolvedKnnMethodContext) == false) { | ||
return CompressionLevel.x1; | ||
} | ||
Encoder encoder = encoderMap.get(getEncoderName(resolvedKnnMethodContext)); | ||
if (encoder == null) { | ||
return CompressionLevel.NOT_CONFIGURED; | ||
} | ||
return encoder.calculateCompressionLevel(getEncoderComponentContext(resolvedKnnMethodContext), knnMethodConfigContext); | ||
} | ||
|
||
protected void resolveMethodParams( | ||
MethodComponentContext methodComponentContext, | ||
KNNMethodConfigContext knnMethodConfigContext, | ||
MethodComponent methodComponent | ||
) { | ||
Map<String, Object> resolvedParams = MethodComponent.getParameterMapWithDefaultsAdded( | ||
methodComponentContext, | ||
methodComponent, | ||
knnMethodConfigContext | ||
); | ||
methodComponentContext.getParameters().putAll(resolvedParams); | ||
} | ||
|
||
protected KNNMethodContext initResolvedKNNMethodContext( | ||
KNNMethodContext originalMethodContext, | ||
KNNEngine knnEngine, | ||
SpaceType spaceType, | ||
String methodName | ||
) { | ||
if (originalMethodContext == null) { | ||
return new KNNMethodContext(knnEngine, spaceType, new MethodComponentContext(methodName, new HashMap<>())); | ||
} | ||
return new KNNMethodContext(originalMethodContext); | ||
} | ||
|
||
protected String getEncoderName(KNNMethodContext knnMethodContext) { | ||
if (isEncoderSpecified(knnMethodContext) == false) { | ||
return null; | ||
} | ||
|
||
MethodComponentContext methodComponentContext = getEncoderComponentContext(knnMethodContext); | ||
if (methodComponentContext == null) { | ||
return null; | ||
} | ||
|
||
return methodComponentContext.getName(); | ||
} | ||
|
||
protected MethodComponentContext getEncoderComponentContext(KNNMethodContext knnMethodContext) { | ||
if (isEncoderSpecified(knnMethodContext) == false) { | ||
return null; | ||
} | ||
|
||
return (MethodComponentContext) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_ENCODER_PARAMETER); | ||
} | ||
|
||
/** | ||
* Determine if the encoder parameter is specified | ||
* | ||
* @param knnMethodContext {@link KNNMethodContext} | ||
* @return true is the encoder is specified in the structure; false otherwise | ||
*/ | ||
protected boolean isEncoderSpecified(KNNMethodContext knnMethodContext) { | ||
return knnMethodContext != null | ||
&& knnMethodContext.getMethodComponentContext().getParameters() != null | ||
&& knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_ENCODER_PARAMETER); | ||
} | ||
|
||
protected boolean shouldEncoderBeResolved(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) { | ||
// The encoder should not be resolved if: | ||
// 1. The encoder is specified | ||
// 2. The compression is x1 | ||
// 3. The compression is not specified and the mode is not disk-based | ||
if (isEncoderSpecified(knnMethodContext)) { | ||
return false; | ||
} | ||
|
||
if (knnMethodConfigContext.getCompressionLevel() == CompressionLevel.x1) { | ||
return false; | ||
} | ||
|
||
if (CompressionLevel.isConfigured(knnMethodConfigContext.getCompressionLevel()) == false | ||
&& Mode.ON_DISK != knnMethodConfigContext.getMode()) { | ||
return false; | ||
} | ||
|
||
if (VectorDataType.FLOAT != knnMethodConfigContext.getVectorDataType()) { | ||
return false; | ||
} | ||
|
||
return true; | ||
} | ||
|
||
protected ValidationException validateNotTrainingContext( | ||
boolean shouldRequireTraining, | ||
KNNEngine knnEngine, | ||
ValidationException validationException | ||
) { | ||
if (shouldRequireTraining) { | ||
validationException = validationException == null ? new ValidationException() : validationException; | ||
validationException.addValidationError( | ||
String.format(Locale.ROOT, "Cannot use \"%s\" engine from training context", knnEngine.getName()) | ||
); | ||
} | ||
|
||
return validationException; | ||
} | ||
|
||
protected ValidationException validateCompressionSupported( | ||
CompressionLevel compressionLevel, | ||
Set<CompressionLevel> supportedCompressionLevels, | ||
KNNEngine knnEngine, | ||
ValidationException validationException | ||
) { | ||
if (CompressionLevel.isConfigured(compressionLevel) && supportedCompressionLevels.contains(compressionLevel) == false) { | ||
validationException = validationException == null ? new ValidationException() : validationException; | ||
validationException.addValidationError( | ||
String.format(Locale.ROOT, "\"%s\" does not support \"%s\" compression", knnEngine.getName(), compressionLevel.getName()) | ||
); | ||
} | ||
return validationException; | ||
} | ||
|
||
protected ValidationException validateCompressionNotx1WhenOnDisk( | ||
KNNMethodConfigContext knnMethodConfigContext, | ||
ValidationException validationException | ||
) { | ||
if (knnMethodConfigContext.getCompressionLevel() == CompressionLevel.x1 && knnMethodConfigContext.getMode() == Mode.ON_DISK) { | ||
validationException = validationException == null ? new ValidationException() : validationException; | ||
validationException.addValidationError( | ||
String.format(Locale.ROOT, "Cannot specify \"x1\" compression level when using \"%s\" mode", Mode.ON_DISK.getName()) | ||
); | ||
} | ||
return validationException; | ||
} | ||
|
||
protected void validateCompressionConflicts(CompressionLevel originalCompressionLevel, CompressionLevel resolvedCompressionLevel) { | ||
if (CompressionLevel.isConfigured(originalCompressionLevel) | ||
&& CompressionLevel.isConfigured(resolvedCompressionLevel) | ||
&& resolvedCompressionLevel != originalCompressionLevel) { | ||
ValidationException validationException = new ValidationException(); | ||
validationException.addValidationError("Cannot specify an encoder that conflicts with the provided compression level"); | ||
throw validationException; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
62 changes: 62 additions & 0 deletions
62
src/main/java/org/opensearch/knn/index/engine/EngineResolver.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index.engine; | ||
|
||
import org.opensearch.knn.index.mapper.CompressionLevel; | ||
import org.opensearch.knn.index.mapper.Mode; | ||
|
||
/** | ||
* Figures out what {@link KNNEngine} to use based on configuration details | ||
*/ | ||
public final class EngineResolver { | ||
|
||
public static final EngineResolver INSTANCE = new EngineResolver(); | ||
|
||
private EngineResolver() {} | ||
|
||
/** | ||
* Based on the provided {@link Mode} and {@link CompressionLevel}, resolve to a {@link KNNEngine}. | ||
* | ||
* @param knnMethodConfigContext configuration context | ||
* @param knnMethodContext KNNMethodContext | ||
* @param requiresTraining whether config requires training | ||
* @return {@link KNNEngine} | ||
*/ | ||
public KNNEngine resolveEngine( | ||
KNNMethodConfigContext knnMethodConfigContext, | ||
KNNMethodContext knnMethodContext, | ||
boolean requiresTraining | ||
) { | ||
// User configuration gets precedence | ||
if (knnMethodContext != null && knnMethodContext.isEngineConfigured()) { | ||
return knnMethodContext.getKnnEngine(); | ||
} | ||
|
||
// Faiss is the only engine that supports training, so we default to faiss here for now | ||
if (requiresTraining) { | ||
return KNNEngine.FAISS; | ||
} | ||
|
||
Mode mode = knnMethodConfigContext.getMode(); | ||
CompressionLevel compressionLevel = knnMethodConfigContext.getCompressionLevel(); | ||
// If both mode and compression are not specified, we can just default | ||
if (Mode.isConfigured(mode) == false && CompressionLevel.isConfigured(compressionLevel) == false) { | ||
return KNNEngine.DEFAULT; | ||
} | ||
|
||
// For 1x, we need to default to faiss if mode is provided and use nmslib otherwise | ||
if (CompressionLevel.isConfigured(compressionLevel) == false || compressionLevel == CompressionLevel.x1) { | ||
return mode == Mode.ON_DISK ? KNNEngine.FAISS : KNNEngine.DEFAULT; | ||
} | ||
|
||
// Lucene is only engine that supports 4x - so we have to default to it here. | ||
if (compressionLevel == CompressionLevel.x4) { | ||
return KNNEngine.LUCENE; | ||
} | ||
|
||
return KNNEngine.FAISS; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.