From 7f00737b207a539d1f0b7b7920f2b14d8d092c6b Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 16 Jul 2024 12:53:55 -0700 Subject: [PATCH] Modify unit tests and refactor rounding to private method Signed-off-by: Ryan Bogan --- .../knn/index/codec/util/KNNCodecUtil.java | 23 +++++++++---------- .../KNN80DocValuesConsumerTests.java | 12 ++++++---- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java index 554d5cb39..c22131030 100644 --- a/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java +++ b/src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java @@ -97,27 +97,26 @@ public static long calculateArraySize(int numVectors, int vectorLength, Serializ // For more details on array memory usage in Java, visit https://www.javamex.com/tutorials/memory/array_memory_usage.shtml if (serializationMode == SerializationMode.ARRAY) { int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE; - if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { - vectorSize += (JAVA_ROUNDING_NUMBER - vectorSize % JAVA_ROUNDING_NUMBER); - } + vectorSize = roundVectorSize(vectorSize); int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE) + JAVA_ARRAY_HEADER_SIZE; - if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) { - vectorsSize += (JAVA_ROUNDING_NUMBER - vectorsSize % JAVA_ROUNDING_NUMBER); - } + vectorsSize = roundVectorSize(vectorsSize); return vectorsSize; } else { int vectorSize = vectorLength * FLOAT_BYTE_SIZE; - if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { - vectorSize += (JAVA_ROUNDING_NUMBER - vectorSize % JAVA_ROUNDING_NUMBER); - } + vectorSize = roundVectorSize(vectorSize); int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE); - if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) { - vectorsSize += (JAVA_ROUNDING_NUMBER - vectorsSize % JAVA_ROUNDING_NUMBER); - } + vectorsSize = roundVectorSize(vectorsSize); return vectorsSize; } } + private static int roundVectorSize(int vectorSize) { + if (vectorSize % JAVA_ROUNDING_NUMBER != 0) { + return vectorSize + (JAVA_ROUNDING_NUMBER - vectorSize % JAVA_ROUNDING_NUMBER); + } + return vectorSize; + } + public static String buildEngineFileName(String segmentName, String latestBuildVersion, String fieldName, String extension) { return String.format("%s%s%s", buildEngineFilePrefix(segmentName), latestBuildVersion, buildEngineFileSuffix(fieldName, extension)); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index f7c9f3eb8..58927e9ac 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -192,6 +192,7 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); + long initialMergeTotalSize = KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue(); // Add documents to the field KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); @@ -212,7 +213,7 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); + assertEquals(initialMergeTotalSize + 6800, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); } public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException { @@ -245,6 +246,7 @@ public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); + long initialMergeTotalSize = KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue(); // Add documents to the field KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); @@ -265,7 +267,7 @@ public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); + assertEquals(initialMergeTotalSize + 6800, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); } public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException { @@ -306,6 +308,7 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); + long initialMergeTotalSize = KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue(); // Add documents to the field KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); @@ -326,7 +329,7 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); + assertEquals(initialMergeTotalSize + 6800, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); } public void testAddKNNBinaryField_fromModel_faiss() throws IOException, ExecutionException, InterruptedException { @@ -401,6 +404,7 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue(); long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue(); + long initialMergeTotalSize = KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue(); // Add documents to the field KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state); @@ -421,7 +425,7 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue()); assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue()); - assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); + assertEquals(initialMergeTotalSize + 6800, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue()); }