Skip to content

Commit

Permalink
Add more unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jul 22, 2024
1 parent c66a592 commit 5ff5db4
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public class KNNConstants {
public static final double MINIMUM_CONFIDENCE_INTERVAL = 0.9;
public static final double MAXIMUM_CONFIDENCE_INTERVAL = 1.0;
public static final String LUCENE_SQ_BITS = "bits";
public static final Integer LUCENE_SQ_DEFAULT_BITS = 7;
public static final int LUCENE_SQ_DEFAULT_BITS = 7;

// nmslib specific constants
public static final String NMSLIB_NAME = "nmslib";
Expand Down
17 changes: 7 additions & 10 deletions src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,9 @@ public void testRadiusSearch_usingScoreThreshold_withFilter_usingCosineMetrics_u
}

@SneakyThrows
public void testSQ_whenInvalidBits_thenThrowException() {
public void testSQ_withInvalidParams_thenThrowException() {

// Use invalid number of bits for the bits param which throws an exception
int bits = -1;
expectThrows(
ResponseException.class,
Expand All @@ -488,11 +490,9 @@ public void testSQ_whenInvalidBits_thenThrowException() {
MINIMUM_CONFIDENCE_INTERVAL
)
);
}

@SneakyThrows
public void testSQ_whenInvalidConfidenceInterval_thenThrowException() {
double confidenceInterval = 2.5;
// Use invalid value for confidence_interval param which throws an exception
double confidenceInterval = -2.5;
expectThrows(
ResponseException.class,
() -> createKnnIndexMappingWithLuceneEngineAndSQEncoder(
Expand All @@ -503,11 +503,9 @@ public void testSQ_whenInvalidConfidenceInterval_thenThrowException() {
confidenceInterval
)
);
}

@SneakyThrows
public void testSQ_withByteVectorDataType_thenThrowException() {
Exception ex = expectThrows(
// Use "byte" data_type with sq encoder which throws an exception
expectThrows(
ResponseException.class,
() -> createKnnIndexMappingWithLuceneEngineAndSQEncoder(
DIMENSION,
Expand All @@ -517,7 +515,6 @@ public void testSQ_withByteVectorDataType_thenThrowException() {
MINIMUM_CONFIDENCE_INTERVAL
)
);
assertTrue(ex.getMessage(), ex.getMessage().contains("data type does not support"));
}

@SneakyThrows
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.codec.params;

import junit.framework.TestCase;
import org.opensearch.knn.index.MethodComponentContext;

import java.util.HashMap;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_DEFAULT_BITS;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M;
import static org.opensearch.knn.common.KNNConstants.MINIMUM_CONFIDENCE_INTERVAL;

public class KNNScalarQuantizedVectorsFormatParamsTests extends TestCase {
private static final int DEFAULT_MAX_CONNECTIONS = 16;
private static final int DEFAULT_BEAM_WIDTH = 100;

public void testInitParams_whenCalled_thenReturnDefaultParams() {
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(
getDefaultParamsForConstructor(),
DEFAULT_MAX_CONNECTIONS,
DEFAULT_BEAM_WIDTH
);

assertEquals(DEFAULT_MAX_CONNECTIONS, knnScalarQuantizedVectorsFormatParams.getMaxConnections());
assertEquals(DEFAULT_BEAM_WIDTH, knnScalarQuantizedVectorsFormatParams.getBeamWidth());
assertNull(knnScalarQuantizedVectorsFormatParams.getConfidenceInterval());
assertTrue(knnScalarQuantizedVectorsFormatParams.isCompressFlag());
assertEquals(LUCENE_SQ_DEFAULT_BITS, knnScalarQuantizedVectorsFormatParams.getBits());
}

public void testInitParams_whenCalled_thenReturnParams() {
int m = 64;
int efConstruction = 128;

Map<String, Object> encoderParams = new HashMap<>();
encoderParams.put(LUCENE_SQ_CONFIDENCE_INTERVAL, MINIMUM_CONFIDENCE_INTERVAL);
MethodComponentContext encoderComponentContext = new MethodComponentContext(ENCODER_SQ, encoderParams);

Map<String, Object> params = new HashMap<>();
params.put(METHOD_ENCODER_PARAMETER, encoderComponentContext);
params.put(METHOD_PARAMETER_M, m);
params.put(METHOD_PARAMETER_EF_CONSTRUCTION, efConstruction);

KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(
params,
DEFAULT_MAX_CONNECTIONS,
DEFAULT_BEAM_WIDTH
);

assertEquals(m, knnScalarQuantizedVectorsFormatParams.getMaxConnections());
assertEquals(efConstruction, knnScalarQuantizedVectorsFormatParams.getBeamWidth());
assertEquals((float) MINIMUM_CONFIDENCE_INTERVAL, knnScalarQuantizedVectorsFormatParams.getConfidenceInterval());
assertTrue(knnScalarQuantizedVectorsFormatParams.isCompressFlag());
assertEquals(LUCENE_SQ_DEFAULT_BITS, knnScalarQuantizedVectorsFormatParams.getBits());
}

public void testValidate_whenCalled_thenReturnTrue() {
Map<String, Object> params = getDefaultParamsForConstructor();
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(
params,
DEFAULT_MAX_CONNECTIONS,
DEFAULT_BEAM_WIDTH
);
assertTrue(knnScalarQuantizedVectorsFormatParams.validate(params));
}

public void testValidate_whenCalled_thenReturnFalse() {
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(
getDefaultParamsForConstructor(),
DEFAULT_MAX_CONNECTIONS,
DEFAULT_BEAM_WIDTH
);
Map<String, Object> params = new HashMap<>();

// Return false if encoder value is null
params.put(METHOD_ENCODER_PARAMETER, null);
assertFalse(knnScalarQuantizedVectorsFormatParams.validate(params));

// Return false if encoder value is not an instance of MethodComponentContext
params.replace(METHOD_ENCODER_PARAMETER, "dummy string");
assertFalse(knnScalarQuantizedVectorsFormatParams.validate(params));

// Return false if encoder name is not "sq"
MethodComponentContext encoderComponentContext = new MethodComponentContext("invalid encoder name", new HashMap<>());
params.replace(METHOD_ENCODER_PARAMETER, encoderComponentContext);
assertFalse(knnScalarQuantizedVectorsFormatParams.validate(params));
}

private Map<String, Object> getDefaultParamsForConstructor() {
MethodComponentContext encoderComponentContext = new MethodComponentContext(ENCODER_SQ, new HashMap<>());
Map<String, Object> params = new HashMap<>();
params.put(METHOD_ENCODER_PARAMETER, encoderComponentContext);
return params;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.codec.params;

import junit.framework.TestCase;

import java.util.HashMap;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M;

public class KNNVectorsFormatParamsTests extends TestCase {
private static final int DEFAULT_MAX_CONNECTIONS = 16;
private static final int DEFAULT_BEAM_WIDTH = 100;

public void testInitParams_whenCalled_thenReturnDefaultParams() {
KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(
new HashMap<>(),
DEFAULT_MAX_CONNECTIONS,
DEFAULT_BEAM_WIDTH
);
assertEquals(DEFAULT_MAX_CONNECTIONS, knnVectorsFormatParams.getMaxConnections());
assertEquals(DEFAULT_BEAM_WIDTH, knnVectorsFormatParams.getBeamWidth());
}

public void testInitParams_whenCalled_thenReturnParams() {
int m = 64;
int efConstruction = 128;
Map<String, Object> params = new HashMap<>();
params.put(METHOD_PARAMETER_M, m);
params.put(METHOD_PARAMETER_EF_CONSTRUCTION, efConstruction);

KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, DEFAULT_MAX_CONNECTIONS, DEFAULT_BEAM_WIDTH);
assertEquals(m, knnVectorsFormatParams.getMaxConnections());
assertEquals(efConstruction, knnVectorsFormatParams.getBeamWidth());
}

public void testValidate_whenCalled_thenReturnTrue() {
KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(
new HashMap<>(),
DEFAULT_MAX_CONNECTIONS,
DEFAULT_BEAM_WIDTH
);
assertTrue(knnVectorsFormatParams.validate(new HashMap<>()));
}
}

0 comments on commit 5ff5db4

Please sign in to comment.