Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Sep 16, 2024
1 parent bfa734d commit 95b6b89
Show file tree
Hide file tree
Showing 2 changed files with 278 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand All @@ -45,6 +46,7 @@
import static org.mockito.Mockito.mockConstruction;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

@RequiredArgsConstructor
Expand All @@ -68,6 +70,7 @@ public class NativeEngines990KnnVectorsWriterFlushTests extends OpenSearchTestCa
private final String description;
private final List<Map<Integer, float[]>> vectorsPerField;
private static final Integer BUILD_GRAPH_ALWAYS_THRESHOLD = 0;
private static final Integer BUILD_GRAPH_NEVER_THRESHOLD = -1;

@Override
public void setUp() throws Exception {
Expand Down Expand Up @@ -180,6 +183,161 @@ public void testFlush() {
}
}

public void testFlush_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled() throws IOException {
// Given
List<KNNVectorValues<float[]>> expectedVectorValues = new ArrayList<>();
IntStream.range(0, vectorsPerField.size()).forEach(i -> {
final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues(
new ArrayList<>(vectorsPerField.get(i).values())
);
final KNNVectorValues<float[]> knnVectorValues = KNNVectorValuesFactory.getVectorValues(
VectorDataType.FLOAT,
randomVectorValues
);
expectedVectorValues.add(knnVectorValues);

});

final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter(
segmentWriteState,
flatVectorsWriter,
BUILD_GRAPH_NEVER_THRESHOLD
);

try (
MockedStatic<NativeEngineFieldVectorsWriter> fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class);
MockedStatic<KNNVectorValuesFactory> knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class);
MockedStatic<QuantizationService> quantizationServiceMockedStatic = mockStatic(QuantizationService.class);
MockedStatic<NativeIndexWriter> nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class);
MockedConstruction<KNN990QuantizationStateWriter> knn990QuantWriterMockedConstruction = mockConstruction(
KNN990QuantizationStateWriter.class
);
) {
quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService);
IntStream.range(0, vectorsPerField.size()).forEach(i -> {
final FieldInfo fieldInfo = fieldInfo(
i,
VectorEncoding.FLOAT32,
Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss")
);

NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i));
fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream))
.thenReturn(field);
try {
nativeEngineWriter.addField(fieldInfo);
} catch (Exception e) {
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));

when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null);
nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null))
.thenReturn(nativeIndexWriter);
});

doAnswer(answer -> {
Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion
return null;
}).when(nativeIndexWriter).flushIndex(any(), anyInt());

// When
nativeEngineWriter.flush(5, null);

// Then
verify(flatVectorsWriter).flush(5, null);
if (vectorsPerField.size() > 0) {
assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size());
}
verifyNoInteractions(nativeIndexWriter);
}
}

public void testFlush_whenThresholdIsEqualToNumberOfVectors_thenNativeIndexWriterIsCalled() throws IOException {
// Given
List<KNNVectorValues<float[]>> expectedVectorValues = new ArrayList<>();
IntStream.range(0, vectorsPerField.size()).forEach(i -> {
final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues(
new ArrayList<>(vectorsPerField.get(i).values())
);
final KNNVectorValues<float[]> knnVectorValues = KNNVectorValuesFactory.getVectorValues(
VectorDataType.FLOAT,
randomVectorValues
);
expectedVectorValues.add(knnVectorValues);

});

final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter(
segmentWriteState,
flatVectorsWriter,
expectedVectorValues.size()
);

try (
MockedStatic<NativeEngineFieldVectorsWriter> fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class);
MockedStatic<KNNVectorValuesFactory> knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class);
MockedStatic<QuantizationService> quantizationServiceMockedStatic = mockStatic(QuantizationService.class);
MockedStatic<NativeIndexWriter> nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class);
MockedConstruction<KNN990QuantizationStateWriter> knn990QuantWriterMockedConstruction = mockConstruction(
KNN990QuantizationStateWriter.class
);
) {
quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService);
IntStream.range(0, vectorsPerField.size()).forEach(i -> {
final FieldInfo fieldInfo = fieldInfo(
i,
VectorEncoding.FLOAT32,
Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss")
);

NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i));
fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream))
.thenReturn(field);
try {
nativeEngineWriter.addField(fieldInfo);
} catch (Exception e) {
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));

when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null);
nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null))
.thenReturn(nativeIndexWriter);
});

doAnswer(answer -> {
Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion
return null;
}).when(nativeIndexWriter).flushIndex(any(), anyInt());

// When
nativeEngineWriter.flush(5, null);

// Then
verify(flatVectorsWriter).flush(5, null);
if (vectorsPerField.size() > 0) {
assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size());
assertTrue((long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() > 0);
}
IntStream.range(0, vectorsPerField.size()).forEach(i -> {
try {
verify(nativeIndexWriter).flushIndex(expectedVectorValues.get(i), vectorsPerField.get(i).size());
} catch (Exception e) {
throw new RuntimeException(e);
}
});
}
}

@SneakyThrows
public void testFlush_WithQuantization() {
// Given
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -74,6 +75,7 @@ public class NativeEngines990KnnVectorsWriterMergeTests extends OpenSearchTestCa
private final String description;
private final Map<Integer, float[]> mergedVectors;
private static final Integer BUILD_GRAPH_ALWAYS_THRESHOLD = 0;
private static final Integer BUILD_GRAPH_NEVER_THRESHOLD = -1;

@Override
public void setUp() throws Exception {
Expand Down Expand Up @@ -151,6 +153,124 @@ public void testMerge() {
}
}

public void testMerge_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled() throws IOException {
// Given
final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues(
new ArrayList<>(mergedVectors.values())
);
final KNNVectorValues<float[]> knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues);
final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter(
segmentWriteState,
flatVectorsWriter,
BUILD_GRAPH_NEVER_THRESHOLD
);
try (
MockedStatic<NativeEngineFieldVectorsWriter> fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class);
MockedStatic<KNNVectorValuesFactory> knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class);
MockedStatic<QuantizationService> quantizationServiceMockedStatic = mockStatic(QuantizationService.class);
MockedStatic<NativeIndexWriter> nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class);
MockedStatic<KnnVectorsWriter.MergedVectorValues> mergedVectorValuesMockedStatic = mockStatic(
KnnVectorsWriter.MergedVectorValues.class
);
MockedConstruction<KNN990QuantizationStateWriter> knn990QuantWriterMockedConstruction = mockConstruction(
KNN990QuantizationStateWriter.class
);
) {
quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService);
final FieldInfo fieldInfo = fieldInfo(
0,
VectorEncoding.FLOAT32,
Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss")
);

NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors);
fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream))
.thenReturn(field);

mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState))
.thenReturn(floatVectorValues);
knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues))
.thenReturn(knnVectorValues);

when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null);
nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null))
.thenReturn(nativeIndexWriter);
doAnswer(answer -> {
Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion
return null;
}).when(nativeIndexWriter).mergeIndex(any(), anyInt());

// When
nativeEngineWriter.mergeOneField(fieldInfo, mergeState);

// Then
verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState);
assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size());
verifyNoInteractions(nativeIndexWriter);
}
}

public void testMerge_whenThresholdIsEqualToNumberOfVectors_thenNativeIndexWriterIsCalled() throws IOException {
// Given
final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues(
new ArrayList<>(mergedVectors.values())
);
final KNNVectorValues<float[]> knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues);
final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter(
segmentWriteState,
flatVectorsWriter,
mergedVectors.size()
);
try (
MockedStatic<NativeEngineFieldVectorsWriter> fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class);
MockedStatic<KNNVectorValuesFactory> knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class);
MockedStatic<QuantizationService> quantizationServiceMockedStatic = mockStatic(QuantizationService.class);
MockedStatic<NativeIndexWriter> nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class);
MockedStatic<KnnVectorsWriter.MergedVectorValues> mergedVectorValuesMockedStatic = mockStatic(
KnnVectorsWriter.MergedVectorValues.class
);
MockedConstruction<KNN990QuantizationStateWriter> knn990QuantWriterMockedConstruction = mockConstruction(
KNN990QuantizationStateWriter.class
);
) {
quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService);
final FieldInfo fieldInfo = fieldInfo(
0,
VectorEncoding.FLOAT32,
Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss")
);

NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, mergedVectors);
fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream))
.thenReturn(field);

mergedVectorValuesMockedStatic.when(() -> KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState))
.thenReturn(floatVectorValues);
knnVectorValuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, floatVectorValues))
.thenReturn(knnVectorValues);

when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(null);
nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, null))
.thenReturn(nativeIndexWriter);
doAnswer(answer -> {
Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion
return null;
}).when(nativeIndexWriter).mergeIndex(any(), anyInt());

// When
nativeEngineWriter.mergeOneField(fieldInfo, mergeState);

// Then
verify(flatVectorsWriter).mergeOneField(fieldInfo, mergeState);
assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size());
if (!mergedVectors.isEmpty()) {
verify(nativeIndexWriter).mergeIndex(knnVectorValues, mergedVectors.size());
} else {
verifyNoInteractions(nativeIndexWriter);
}
}
}

@SneakyThrows
public void testMerge_WithQuantization() {
// Given
Expand Down

0 comments on commit 95b6b89

Please sign in to comment.