From 9063258239e14650da18931473a3aa803ab18daa Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Mon, 22 Jul 2024 15:06:02 -0500 Subject: [PATCH] Add more unit tests Signed-off-by: Naveen Tatikonda --- CHANGELOG.md | 1 - .../opensearch/knn/common/KNNConstants.java | 2 +- .../opensearch/knn/index/LuceneEngineIT.java | 17 ++- ...alarQuantizedVectorsFormatParamsTests.java | 110 ++++++++++++++++++ .../params/KNNVectorsFormatParamsTests.java | 56 +++++++++ 5 files changed, 174 insertions(+), 12 deletions(-) create mode 100644 src/test/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParamsTests.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParamsTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index a65076e323..bf0b46a33c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Add painless script support for hamming with binary vector data type [#1839](https://github.com/opensearch-project/k-NN/pull/1839) * Add binary format support with IVF method in Faiss Engine [#1784](https://github.com/opensearch-project/k-NN/pull/1784) * Add support for Lucene inbuilt Scalar Quantizer [#1848](https://github.com/opensearch-project/k-NN/pull/1848) ->>>>>>> dcbe6ad8 (Add Tests) ### Enhancements * Switch from byte stream to byte ref for serde [#1825](https://github.com/opensearch-project/k-NN/pull/1825) ### Bug Fixes diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 3002e426fc..56f9ffaf89 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -79,7 +79,7 @@ public class KNNConstants { public static final double MINIMUM_CONFIDENCE_INTERVAL = 0.9; public static final double MAXIMUM_CONFIDENCE_INTERVAL = 1.0; public static final String LUCENE_SQ_BITS = "bits"; - public static final Integer LUCENE_SQ_DEFAULT_BITS = 7; + public static final int LUCENE_SQ_DEFAULT_BITS = 7; // nmslib specific constants public static final String NMSLIB_NAME = "nmslib"; diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index 43bac52536..e7f38787d2 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -476,7 +476,9 @@ public void testRadiusSearch_usingScoreThreshold_withFilter_usingCosineMetrics_u } @SneakyThrows - public void testSQ_whenInvalidBits_thenThrowException() { + public void testSQ_withInvalidParams_thenThrowException() { + + // Use invalid number of bits for the bits param which throws an exception int bits = -1; expectThrows( ResponseException.class, @@ -488,11 +490,9 @@ public void testSQ_whenInvalidBits_thenThrowException() { MINIMUM_CONFIDENCE_INTERVAL ) ); - } - @SneakyThrows - public void testSQ_whenInvalidConfidenceInterval_thenThrowException() { - double confidenceInterval = 2.5; + // Use invalid value for confidence_interval param which throws an exception + double confidenceInterval = -2.5; expectThrows( ResponseException.class, () -> createKnnIndexMappingWithLuceneEngineAndSQEncoder( @@ -503,11 +503,9 @@ public void testSQ_whenInvalidConfidenceInterval_thenThrowException() { confidenceInterval ) ); - } - @SneakyThrows - public void testSQ_withByteVectorDataType_thenThrowException() { - Exception ex = expectThrows( + // Use "byte" data_type with sq encoder which throws an exception + expectThrows( ResponseException.class, () -> createKnnIndexMappingWithLuceneEngineAndSQEncoder( DIMENSION, @@ -517,7 +515,6 @@ public void testSQ_withByteVectorDataType_thenThrowException() { MINIMUM_CONFIDENCE_INTERVAL ) ); - assertTrue(ex.getMessage(), ex.getMessage().contains("data type does not support")); } @SneakyThrows diff --git a/src/test/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParamsTests.java b/src/test/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParamsTests.java new file mode 100644 index 0000000000..bcba8ebbd1 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/params/KNNScalarQuantizedVectorsFormatParamsTests.java @@ -0,0 +1,110 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.codec.params; + +import junit.framework.TestCase; +import org.opensearch.knn.index.MethodComponentContext; + +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; +import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_CONFIDENCE_INTERVAL; +import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_DEFAULT_BITS; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.MINIMUM_CONFIDENCE_INTERVAL; + +public class KNNScalarQuantizedVectorsFormatParamsTests extends TestCase { + private static final int DEFAULT_MAX_CONNECTIONS = 16; + private static final int DEFAULT_BEAM_WIDTH = 100; + + public void testInitParams_whenCalled_thenReturnDefaultParams() { + KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( + getDefaultParamsForConstructor(), + DEFAULT_MAX_CONNECTIONS, + DEFAULT_BEAM_WIDTH + ); + + assertEquals(DEFAULT_MAX_CONNECTIONS, knnScalarQuantizedVectorsFormatParams.getMaxConnections()); + assertEquals(DEFAULT_BEAM_WIDTH, knnScalarQuantizedVectorsFormatParams.getBeamWidth()); + assertNull(knnScalarQuantizedVectorsFormatParams.getConfidenceInterval()); + assertTrue(knnScalarQuantizedVectorsFormatParams.isCompressFlag()); + assertEquals(LUCENE_SQ_DEFAULT_BITS, knnScalarQuantizedVectorsFormatParams.getBits()); + } + + public void testInitParams_whenCalled_thenReturnParams() { + int m = 64; + int efConstruction = 128; + + Map encoderParams = new HashMap<>(); + encoderParams.put(LUCENE_SQ_CONFIDENCE_INTERVAL, MINIMUM_CONFIDENCE_INTERVAL); + MethodComponentContext encoderComponentContext = new MethodComponentContext(ENCODER_SQ, encoderParams); + + Map params = new HashMap<>(); + params.put(METHOD_ENCODER_PARAMETER, encoderComponentContext); + params.put(METHOD_PARAMETER_M, m); + params.put(METHOD_PARAMETER_EF_CONSTRUCTION, efConstruction); + + KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( + params, + DEFAULT_MAX_CONNECTIONS, + DEFAULT_BEAM_WIDTH + ); + + assertEquals(m, knnScalarQuantizedVectorsFormatParams.getMaxConnections()); + assertEquals(efConstruction, knnScalarQuantizedVectorsFormatParams.getBeamWidth()); + assertEquals((float) MINIMUM_CONFIDENCE_INTERVAL, knnScalarQuantizedVectorsFormatParams.getConfidenceInterval()); + assertTrue(knnScalarQuantizedVectorsFormatParams.isCompressFlag()); + assertEquals(LUCENE_SQ_DEFAULT_BITS, knnScalarQuantizedVectorsFormatParams.getBits()); + } + + public void testValidate_whenCalled_thenReturnTrue() { + Map params = getDefaultParamsForConstructor(); + KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( + params, + DEFAULT_MAX_CONNECTIONS, + DEFAULT_BEAM_WIDTH + ); + assertTrue(knnScalarQuantizedVectorsFormatParams.validate(params)); + } + + public void testValidate_whenCalled_thenReturnFalse() { + KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams( + getDefaultParamsForConstructor(), + DEFAULT_MAX_CONNECTIONS, + DEFAULT_BEAM_WIDTH + ); + Map params = new HashMap<>(); + + // Return false if encoder value is null + params.put(METHOD_ENCODER_PARAMETER, null); + assertFalse(knnScalarQuantizedVectorsFormatParams.validate(params)); + + // Return false if encoder value is not an instance of MethodComponentContext + params.replace(METHOD_ENCODER_PARAMETER, "dummy string"); + assertFalse(knnScalarQuantizedVectorsFormatParams.validate(params)); + + // Return false if encoder name is not "sq" + MethodComponentContext encoderComponentContext = new MethodComponentContext("invalid encoder name", new HashMap<>()); + params.replace(METHOD_ENCODER_PARAMETER, encoderComponentContext); + assertFalse(knnScalarQuantizedVectorsFormatParams.validate(params)); + } + + private Map getDefaultParamsForConstructor() { + MethodComponentContext encoderComponentContext = new MethodComponentContext(ENCODER_SQ, new HashMap<>()); + Map params = new HashMap<>(); + params.put(METHOD_ENCODER_PARAMETER, encoderComponentContext); + return params; + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParamsTests.java b/src/test/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParamsTests.java new file mode 100644 index 0000000000..dca054046d --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/params/KNNVectorsFormatParamsTests.java @@ -0,0 +1,56 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.codec.params; + +import junit.framework.TestCase; + +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; + +public class KNNVectorsFormatParamsTests extends TestCase { + private static final int DEFAULT_MAX_CONNECTIONS = 16; + private static final int DEFAULT_BEAM_WIDTH = 100; + + public void testInitParams_whenCalled_thenReturnDefaultParams() { + KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams( + new HashMap<>(), + DEFAULT_MAX_CONNECTIONS, + DEFAULT_BEAM_WIDTH + ); + assertEquals(DEFAULT_MAX_CONNECTIONS, knnVectorsFormatParams.getMaxConnections()); + assertEquals(DEFAULT_BEAM_WIDTH, knnVectorsFormatParams.getBeamWidth()); + } + + public void testInitParams_whenCalled_thenReturnParams() { + int m = 64; + int efConstruction = 128; + Map params = new HashMap<>(); + params.put(METHOD_PARAMETER_M, m); + params.put(METHOD_PARAMETER_EF_CONSTRUCTION, efConstruction); + + KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, DEFAULT_MAX_CONNECTIONS, DEFAULT_BEAM_WIDTH); + assertEquals(m, knnVectorsFormatParams.getMaxConnections()); + assertEquals(efConstruction, knnVectorsFormatParams.getBeamWidth()); + } + + public void testValidate_whenCalled_thenReturnTrue() { + KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams( + new HashMap<>(), + DEFAULT_MAX_CONNECTIONS, + DEFAULT_BEAM_WIDTH + ); + assertTrue(knnVectorsFormatParams.validate(new HashMap<>())); + } +}