Skip to content

Commit

Permalink
Modify unit tests and refactor rounding to private method
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Bogan <[email protected]>
  • Loading branch information
ryanbogan committed Jul 16, 2024
1 parent 6bd10bd commit 7f00737
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
23 changes: 11 additions & 12 deletions src/main/java/org/opensearch/knn/index/codec/util/KNNCodecUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,27 +97,26 @@ public static long calculateArraySize(int numVectors, int vectorLength, Serializ
// For more details on array memory usage in Java, visit https://www.javamex.com/tutorials/memory/array_memory_usage.shtml
if (serializationMode == SerializationMode.ARRAY) {
int vectorSize = vectorLength * FLOAT_BYTE_SIZE + JAVA_ARRAY_HEADER_SIZE;
if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
vectorSize += (JAVA_ROUNDING_NUMBER - vectorSize % JAVA_ROUNDING_NUMBER);
}
vectorSize = roundVectorSize(vectorSize);
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE) + JAVA_ARRAY_HEADER_SIZE;
if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
vectorsSize += (JAVA_ROUNDING_NUMBER - vectorsSize % JAVA_ROUNDING_NUMBER);
}
vectorsSize = roundVectorSize(vectorsSize);
return vectorsSize;
} else {
int vectorSize = vectorLength * FLOAT_BYTE_SIZE;
if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
vectorSize += (JAVA_ROUNDING_NUMBER - vectorSize % JAVA_ROUNDING_NUMBER);
}
vectorSize = roundVectorSize(vectorSize);
int vectorsSize = numVectors * (vectorSize + JAVA_REFERENCE_SIZE);
if (vectorsSize % JAVA_ROUNDING_NUMBER != 0) {
vectorsSize += (JAVA_ROUNDING_NUMBER - vectorsSize % JAVA_ROUNDING_NUMBER);
}
vectorsSize = roundVectorSize(vectorsSize);
return vectorsSize;
}
}

private static int roundVectorSize(int vectorSize) {
if (vectorSize % JAVA_ROUNDING_NUMBER != 0) {
return vectorSize + (JAVA_ROUNDING_NUMBER - vectorSize % JAVA_ROUNDING_NUMBER);
}
return vectorSize;
}

public static String buildEngineFileName(String segmentName, String latestBuildVersion, String fieldName, String extension) {
return String.format("%s%s%s", buildEngineFilePrefix(segmentName), latestBuildVersion, buildEngineFileSuffix(fieldName, extension));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException

long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue();
long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue();
long initialMergeTotalSize = KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue();

// Add documents to the field
KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state);
Expand All @@ -212,7 +213,7 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException
assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue());
assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue());
assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue());
assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue());
assertEquals(initialMergeTotalSize + 6800, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue());
}

public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException {
Expand Down Expand Up @@ -245,6 +246,7 @@ public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException

long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue();
long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue();
long initialMergeTotalSize = KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue();

// Add documents to the field
KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state);
Expand All @@ -265,7 +267,7 @@ public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException
assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue());
assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue());
assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue());
assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue());
assertEquals(initialMergeTotalSize + 6800, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue());
}

public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException {
Expand Down Expand Up @@ -306,6 +308,7 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException

long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue();
long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue();
long initialMergeTotalSize = KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue();

// Add documents to the field
KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state);
Expand All @@ -326,7 +329,7 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException
assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue());
assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue());
assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue());
assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue());
assertEquals(initialMergeTotalSize + 6800, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue());
}

public void testAddKNNBinaryField_fromModel_faiss() throws IOException, ExecutionException, InterruptedException {
Expand Down Expand Up @@ -401,6 +404,7 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio

long initialRefreshOperations = KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue();
long initialMergeOperations = KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue();
long initialMergeTotalSize = KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue();

// Add documents to the field
KNN80DocValuesConsumer knn80DocValuesConsumer = new KNN80DocValuesConsumer(null, state);
Expand All @@ -421,7 +425,7 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio
assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue());
assertEquals(1 + initialMergeOperations, (long) KNNGraphValue.MERGE_TOTAL_OPERATIONS.getValue());
assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_DOCS.getValue());
assertNotEquals(0, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue());
assertEquals(initialMergeTotalSize + 6800, (long) KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.getValue());

}

Expand Down

0 comments on commit 7f00737

Please sign in to comment.