Skip to content

Commit

Permalink
Adding tests and cleaning up
Browse files Browse the repository at this point in the history
Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Sep 10, 2024
1 parent 1eebcf6 commit 51dce61
Show file tree
Hide file tree
Showing 19 changed files with 1,486 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ protected CompressionLevel resolveCompressionLevelFromMethodContext(
// If the context is null, the compression is not configured or the encoder is not defined, return not configured
// because the method context does not contain this info
if (isEncoderSpecified(resolvedKnnMethodContext) == false) {
return CompressionLevel.NOT_CONFIGURED;
return CompressionLevel.x1;
}
Encoder encoder = encoderMap.get(getEncoderName(resolvedKnnMethodContext));
if (encoder == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ private EngineResolver() {}
* @param requiresTraining whether config requires training
* @return {@link KNNEngine}
*/
public static KNNEngine resolveEngine(
public KNNEngine resolveEngine(
KNNMethodConfigContext knnMethodConfigContext,
KNNMethodContext knnMethodContext,
boolean requiresTraining
Expand All @@ -48,7 +48,7 @@ public static KNNEngine resolveEngine(
}

// For 1x, we need to default to faiss if mode is provided and use nmslib otherwise
if (compressionLevel == CompressionLevel.x1) {
if (CompressionLevel.isConfigured(compressionLevel) == false || compressionLevel == CompressionLevel.x1) {
return mode == Mode.ON_DISK ? KNNEngine.FAISS : KNNEngine.DEFAULT;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
* KNNMethodContext will contain the information necessary to produce a library index from an Opensearch mapping.
* It will encompass all parameters necessary to build the index.
*/
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@AllArgsConstructor(access = AccessLevel.PACKAGE)
@Getter
public class KNNMethodContext implements ToXContentFragment, Writeable {

Expand All @@ -46,6 +46,9 @@ public class KNNMethodContext implements ToXContentFragment, Writeable {
private SpaceType spaceType;
@NonNull
private final MethodComponentContext methodComponentContext;
// Currently, the KNNEngine member variable cannot be null and defaults during parsing to nmslib. However, in order
// to support disk based engine resolution, this value potentially needs to be updated. Thus, this value is used
// to determine if the variable can be overridden or not based on whether the user explicitly set the value during parsing
private boolean isEngineConfigured;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ public static class Builder extends ParametrizedFieldMapper.Builder {
protected ModelDao modelDao;
protected Version indexCreatedVersion;
@Setter
@Getter
private KNNMethodConfigContext knnMethodConfigContext;
@Setter
@Getter
Expand Down Expand Up @@ -496,7 +497,7 @@ private void resolveKNNMethodComponents(
}

// Based on config context, if the user does not set the engine, set it
KNNEngine resolvedKNNEngine = EngineResolver.resolveEngine(
KNNEngine resolvedKNNEngine = EngineResolver.INSTANCE.resolveEngine(
builder.knnMethodConfigContext,
builder.originalParameters.getResolvedKnnMethodContext(),
false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ public TrainingModelRequest(
.mode(mode)
.build();

KNNEngine knnEngine = EngineResolver.resolveEngine(knnMethodConfigContext, knnMethodContext, true);
KNNEngine knnEngine = EngineResolver.INSTANCE.resolveEngine(knnMethodConfigContext, knnMethodContext, true);
ResolvedMethodContext resolvedMethodContext = knnEngine.resolveMethod(knnMethodContext, knnMethodConfigContext, true, spaceType);
this.knnMethodContext = resolvedMethodContext.getKnnMethodContext();
this.compressionLevel = resolvedMethodContext.getCompressionLevel();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;

import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;

public class AbstractMethodResolverTests extends KNNTestCase {

private final static String ENCODER_NAME = "test";
private final static CompressionLevel DEFAULT_COMPRESSION = CompressionLevel.x8;

private final static AbstractMethodResolver TEST_RESOLVER = new AbstractMethodResolver() {
@Override
public ResolvedMethodContext resolveMethod(
KNNMethodContext knnMethodContext,
KNNMethodConfigContext knnMethodConfigContext,
boolean shouldRequireTraining,
SpaceType spaceType
) {
return null;
}
};

private final static Encoder TEST_ENCODER = new Encoder() {

@Override
public MethodComponent getMethodComponent() {
return MethodComponent.Builder.builder(ENCODER_NAME).build();
}

@Override
public CompressionLevel calculateCompressionLevel(
MethodComponentContext encoderContext,
KNNMethodConfigContext knnMethodConfigContext
) {
return DEFAULT_COMPRESSION;
}
};

private final static Map<String, Encoder> ENCODER_MAP = Map.of(ENCODER_NAME, TEST_ENCODER);

public void testResolveCompressionLevelFromMethodContext() {
assertEquals(
CompressionLevel.NOT_CONFIGURED,
TEST_RESOLVER.resolveCompressionLevelFromMethodContext(
new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, MethodComponentContext.EMPTY),
KNNMethodConfigContext.builder().build(),
ENCODER_MAP
)
);
assertEquals(
DEFAULT_COMPRESSION,
TEST_RESOLVER.resolveCompressionLevelFromMethodContext(
new KNNMethodContext(
KNNEngine.DEFAULT,
SpaceType.DEFAULT,
new MethodComponentContext(
METHOD_HNSW,
Map.of(METHOD_ENCODER_PARAMETER, new MethodComponentContext(ENCODER_NAME, Map.of()))
)
),
KNNMethodConfigContext.builder().build(),
ENCODER_MAP
)
);
}

public void testIsEncoderSpecified() {
assertFalse(TEST_RESOLVER.isEncoderSpecified(null));
assertFalse(
TEST_RESOLVER.isEncoderSpecified(new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, MethodComponentContext.EMPTY))
);
assertFalse(
TEST_RESOLVER.isEncoderSpecified(
new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.DEFAULT, new MethodComponentContext(METHOD_HNSW, Map.of()))
)
);
assertTrue(
TEST_RESOLVER.isEncoderSpecified(
new KNNMethodContext(
KNNEngine.DEFAULT,
SpaceType.DEFAULT,
new MethodComponentContext(METHOD_HNSW, Map.of(METHOD_ENCODER_PARAMETER, "test"))
)
)
);
}

public void testShouldEncoderBeResolved() {
assertFalse(
TEST_RESOLVER.shouldEncoderBeResolved(
new KNNMethodContext(
KNNEngine.DEFAULT,
SpaceType.DEFAULT,
new MethodComponentContext(METHOD_HNSW, Map.of(METHOD_ENCODER_PARAMETER, "test"))
),
KNNMethodConfigContext.builder().build()
)
);
assertFalse(
TEST_RESOLVER.shouldEncoderBeResolved(null, KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x1).build())
);
assertFalse(
TEST_RESOLVER.shouldEncoderBeResolved(
null,
KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x1).mode(Mode.ON_DISK).build()
)
);
assertFalse(
TEST_RESOLVER.shouldEncoderBeResolved(
null,
KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.NOT_CONFIGURED).mode(Mode.IN_MEMORY).build()
)
);
assertFalse(
TEST_RESOLVER.shouldEncoderBeResolved(
null,
KNNMethodConfigContext.builder()
.compressionLevel(CompressionLevel.NOT_CONFIGURED)
.mode(Mode.ON_DISK)
.vectorDataType(VectorDataType.BINARY)
.build()
)
);
assertTrue(
TEST_RESOLVER.shouldEncoderBeResolved(
null,
KNNMethodConfigContext.builder()
.compressionLevel(CompressionLevel.NOT_CONFIGURED)
.mode(Mode.ON_DISK)
.vectorDataType(VectorDataType.FLOAT)
.build()
)
);
assertTrue(
TEST_RESOLVER.shouldEncoderBeResolved(
null,
KNNMethodConfigContext.builder()
.compressionLevel(CompressionLevel.x32)
.mode(Mode.ON_DISK)
.vectorDataType(VectorDataType.FLOAT)
.build()
)
);
}
}
152 changes: 152 additions & 0 deletions src/test/java/org/opensearch/knn/index/engine/EngineResolverTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;

public class EngineResolverTests extends KNNTestCase {

private static final EngineResolver ENGINE_RESOLVER = EngineResolver.INSTANCE;

public void testResolveEngine_whenEngineSpecifiedInMethod_thenThatEngine() {
assertEquals(
KNNEngine.LUCENE,
ENGINE_RESOLVER.resolveEngine(
KNNMethodConfigContext.builder().build(),
new KNNMethodContext(KNNEngine.LUCENE, SpaceType.DEFAULT, MethodComponentContext.EMPTY),
false
)
);
}

public void testResolveEngine_whenRequiresTraining_thenFaiss() {
assertEquals(KNNEngine.FAISS, ENGINE_RESOLVER.resolveEngine(KNNMethodConfigContext.builder().build(), null, true));
}

public void testResolveEngine_whenModeAndCompressionAreFalse_thenDefault() {
assertEquals(KNNEngine.DEFAULT, ENGINE_RESOLVER.resolveEngine(KNNMethodConfigContext.builder().build(), null, false));
assertEquals(
KNNEngine.DEFAULT,
ENGINE_RESOLVER.resolveEngine(
KNNMethodConfigContext.builder().build(),
new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.UNDEFINED, MethodComponentContext.EMPTY, false),
false
)
);
}

public void testResolveEngine_whenModeSpecifiedAndCompressionIsNotSpecified_thenDefault() {
assertEquals(KNNEngine.DEFAULT, ENGINE_RESOLVER.resolveEngine(KNNMethodConfigContext.builder().build(), null, false));
assertEquals(
KNNEngine.DEFAULT,
ENGINE_RESOLVER.resolveEngine(
KNNMethodConfigContext.builder().mode(Mode.IN_MEMORY).build(),
new KNNMethodContext(KNNEngine.DEFAULT, SpaceType.UNDEFINED, MethodComponentContext.EMPTY, false),
false
)
);
}

public void testResolveEngine_whenCompressionIs1x_thenEngineBasedOnMode() {
assertEquals(
KNNEngine.FAISS,
ENGINE_RESOLVER.resolveEngine(
KNNMethodConfigContext.builder().mode(Mode.ON_DISK).compressionLevel(CompressionLevel.x1).build(),
null,
false
)
);
assertEquals(
KNNEngine.DEFAULT,
ENGINE_RESOLVER.resolveEngine(KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x1).build(), null, false)
);
}

public void testResolveEngine_whenCompressionIs4x_thenEngineIsLucene() {
assertEquals(
KNNEngine.LUCENE,
ENGINE_RESOLVER.resolveEngine(
KNNMethodConfigContext.builder().mode(Mode.ON_DISK).compressionLevel(CompressionLevel.x4).build(),
null,
false
)
);
assertEquals(
KNNEngine.LUCENE,
ENGINE_RESOLVER.resolveEngine(KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x4).build(), null, false)
);
}

public void testResolveEngine_whenConfiguredForBQ_thenEngineIsFaiss() {
assertEquals(
KNNEngine.FAISS,
ENGINE_RESOLVER.resolveEngine(
KNNMethodConfigContext.builder().mode(Mode.ON_DISK).compressionLevel(CompressionLevel.x2).build(),
null,
false
)
);
assertEquals(
KNNEngine.FAISS,
ENGINE_RESOLVER.resolveEngine(
KNNMethodConfigContext.builder().mode(Mode.IN_MEMORY).compressionLevel(CompressionLevel.x2).build(),
null,
false
)
);
assertEquals(
KNNEngine.FAISS,
ENGINE_RESOLVER.resolveEngine(
KNNMethodConfigContext.builder().mode(Mode.ON_DISK).compressionLevel(CompressionLevel.x8).build(),
null,
false
)
);
assertEquals(
KNNEngine.FAISS,
ENGINE_RESOLVER.resolveEngine(
KNNMethodConfigContext.builder().mode(Mode.IN_MEMORY).compressionLevel(CompressionLevel.x8).build(),
null,
false
)
);
assertEquals(
KNNEngine.FAISS,
ENGINE_RESOLVER.resolveEngine(
KNNMethodConfigContext.builder().mode(Mode.ON_DISK).compressionLevel(CompressionLevel.x16).build(),
null,
false
)
);
assertEquals(
KNNEngine.FAISS,
ENGINE_RESOLVER.resolveEngine(
KNNMethodConfigContext.builder().mode(Mode.IN_MEMORY).compressionLevel(CompressionLevel.x16).build(),
null,
false
)
);
assertEquals(
KNNEngine.FAISS,
ENGINE_RESOLVER.resolveEngine(
KNNMethodConfigContext.builder().mode(Mode.ON_DISK).compressionLevel(CompressionLevel.x32).build(),
null,
false
)
);
assertEquals(
KNNEngine.FAISS,
ENGINE_RESOLVER.resolveEngine(
KNNMethodConfigContext.builder().mode(Mode.IN_MEMORY).compressionLevel(CompressionLevel.x32).build(),
null,
false
)
);
}
}
Loading

0 comments on commit 51dce61

Please sign in to comment.