Skip to content

Commit

Permalink
Fix the force merge with Quantization failures when a segment has del…
Browse files Browse the repository at this point in the history
…eted docs in it (opensearch-project#2046)

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v authored Sep 5, 2024
1 parent cb9ba71 commit da854c9
Show file tree
Hide file tree
Showing 14 changed files with 128 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
final VectorDataType vectorDataType = extractVectorDataType(field);
final KNNVectorValues<?> knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, valuesProducer.getBinary(field));

// For BinaryDocValues (BDV) it is fine to use knnVectorValues.totalLiveDocs(), as we already run the full loop to
// calculate the total live docs
if (isMerge) {
NativeIndexWriter.getWriter(field, state).mergeIndex(knnVectorValues);
NativeIndexWriter.getWriter(field, state).mergeIndex(knnVectorValues, (int) knnVectorValues.totalLiveDocs());
} else {
NativeIndexWriter.getWriter(field, state).flushIndex(knnVectorValues);
NativeIndexWriter.getWriter(field, state).flushIndex(knnVectorValues, (int) knnVectorValues.totalLiveDocs());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.common.StopWatch;
Expand Down Expand Up @@ -63,8 +64,6 @@ public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, Fla

/**
* Add new field for indexing.
* In Lucene, we use single file for all the vector fields so here we need to see how we are going to make things
* work.
* @param fieldInfo {@link FieldInfo}
*/
@Override
Expand Down Expand Up @@ -204,7 +203,7 @@ private <T> KNNVectorValues<T> getKNNVectorValuesForMerge(
*/
@FunctionalInterface
private interface IndexOperation<T> {
void buildAndWrite(NativeIndexWriter writer, KNNVectorValues<T> knnVectorValues) throws IOException;
void buildAndWrite(NativeIndexWriter writer, KNNVectorValues<T> knnVectorValues, int totalLiveDocs) throws IOException;
}

/**
Expand Down Expand Up @@ -248,9 +247,11 @@ private <T, C> void trainAndIndex(
KNNVectorValues<T> knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext);
QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
QuantizationState quantizationState = null;
if (quantizationParams != null) {
// Count the docIds
int totalLiveDocs = getLiveDocs(vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext));
if (quantizationParams != null && totalLiveDocs > 0) {
initQuantizationStateWriterIfNecessary();
quantizationState = quantizationService.train(quantizationParams, knnVectorValues);
quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs);
quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
}
NativeIndexWriter writer = (quantizationParams != null)
Expand All @@ -261,12 +262,27 @@ private <T, C> void trainAndIndex(

StopWatch stopWatch = new StopWatch();
stopWatch.start();
indexOperation.buildAndWrite(writer, knnVectorValues);
indexOperation.buildAndWrite(writer, knnVectorValues, totalLiveDocs);
long time_in_millis = stopWatch.totalTime().millis();
graphBuildTime.incrementBy(time_in_millis);
log.warn("Graph build took " + time_in_millis + " ms for " + operationName);
}

/**
 * Counts the live docs by advancing the given {@link KNNVectorValues} until it reports
 * {@link DocIdSetIterator#NO_MORE_DOCS}. The passed instance is exhausted once this method returns, so never hand
 * in a vectorValues object that you plan to use afterwards.
 *
 * @param vectorValues the vector values to walk; consumed by this call
 * @return the number of docs visited before the values were exhausted
 * @throws IOException propagated from {@code nextDoc()}
 */
private int getLiveDocs(KNNVectorValues<?> vectorValues) throws IOException {
    // vectorValues.totalLiveDocs() only reports the cost for FloatVectorValues, which is not the correct doc count
    // when the segment contains deleted docs. Hence every doc is counted explicitly here.
    int docCount = 0;
    for (int docId = vectorValues.nextDoc(); docId != DocIdSetIterator.NO_MORE_DOCS; docId = vectorValues.nextDoc()) {
        docCount++;
    }
    return docCount;
}

private void initQuantizationStateWriterIfNecessary() throws IOException {
if (quantizationStateWriter == null) {
quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,17 @@ public static DefaultIndexBuildStrategy getInstance() {
* flushed and used to build the index. The index is then written to the specified path using JNI calls.</p>
*
* @param indexInfo The {@link BuildIndexParams} containing the parameters and configuration for building the index.
* @param knnVectorValues The {@link KNNVectorValues} representing the vectors to be indexed.
* @throws IOException If an I/O error occurs during the process of building and writing the index.
*/
public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException {
public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOException {
final KNNVectorValues<?> knnVectorValues = indexInfo.getVectorValues();
// Needed to make sure we don't get 0 dimensions while initializing index
iterateVectorValuesOnce(knnVectorValues);
IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo);

int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector());
try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) {
final List<Integer> transferredDocIds = new ArrayList<>((int) knnVectorValues.totalLiveDocs());
final List<Integer> transferredDocIds = new ArrayList<>(indexInfo.getTotalLiveDocs());

while (knnVectorValues.docId() != NO_MORE_DOCS) {
Object vector = QuantizationIndexUtils.processAndReturnVector(knnVectorValues, indexBuildSetup);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ public static MemOptimizedNativeIndexBuildStrategy getInstance() {
* @param knnVectorValues The {@link KNNVectorValues} representing the vectors to be indexed.
* @throws IOException If an I/O error occurs during the process of building and writing the index.
*/
public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException {
public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOException {
final KNNVectorValues<?> knnVectorValues = indexInfo.getVectorValues();
// Needed to make sure we don't get 0 dimensions while initializing index
iterateVectorValuesOnce(knnVectorValues);
KNNEngine engine = indexInfo.getKnnEngine();
Expand All @@ -62,7 +63,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
// Initialize the index
long indexMemoryAddress = AccessController.doPrivileged(
(PrivilegedAction<Long>) () -> JNIService.initIndex(
knnVectorValues.totalLiveDocs(),
indexInfo.getTotalLiveDocs(),
indexBuildSetup.getDimensions(),
indexParameters,
engine
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.opensearch.knn.index.codec.nativeindex;

import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;

import java.io.IOException;

Expand All @@ -15,5 +14,5 @@
*/
public interface NativeIndexBuildStrategy {

void buildAndWriteIndex(BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException;
void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ public static NativeIndexWriter getWriter(
* @param knnVectorValues
* @throws IOException
*/
public void flushIndex(final KNNVectorValues<?> knnVectorValues) throws IOException {
public void flushIndex(final KNNVectorValues<?> knnVectorValues, int totalLiveDocs) throws IOException {
iterateVectorValuesOnce(knnVectorValues);
buildAndWriteIndex(knnVectorValues);
buildAndWriteIndex(knnVectorValues, totalLiveDocs);
recordRefreshStats();
}

Expand All @@ -117,7 +117,7 @@ public void flushIndex(final KNNVectorValues<?> knnVectorValues) throws IOExcept
* @param knnVectorValues
* @throws IOException
*/
public void mergeIndex(final KNNVectorValues<?> knnVectorValues) throws IOException {
public void mergeIndex(final KNNVectorValues<?> knnVectorValues, int totalLiveDocs) throws IOException {
iterateVectorValuesOnce(knnVectorValues);
if (knnVectorValues.docId() == NO_MORE_DOCS) {
// This is in place so we do not add metrics
Expand All @@ -126,13 +126,13 @@ public void mergeIndex(final KNNVectorValues<?> knnVectorValues) throws IOExcept
}

long bytesPerVector = knnVectorValues.bytesPerVector();
startMergeStats((int) knnVectorValues.totalLiveDocs(), bytesPerVector);
buildAndWriteIndex(knnVectorValues);
endMergeStats((int) knnVectorValues.totalLiveDocs(), bytesPerVector);
startMergeStats(totalLiveDocs, bytesPerVector);
buildAndWriteIndex(knnVectorValues, totalLiveDocs);
endMergeStats(totalLiveDocs, bytesPerVector);
}

private void buildAndWriteIndex(final KNNVectorValues<?> knnVectorValues) throws IOException {
if (knnVectorValues.totalLiveDocs() == 0) {
private void buildAndWriteIndex(final KNNVectorValues<?> knnVectorValues, int totalLiveDocs) throws IOException {
if (totalLiveDocs == 0) {
log.debug("No live docs for field " + fieldInfo.name);
return;
}
Expand All @@ -150,15 +150,21 @@ private void buildAndWriteIndex(final KNNVectorValues<?> knnVectorValues) throws
).toString();
state.directory.createOutput(engineFileName, state.context).close();

final BuildIndexParams nativeIndexParams = indexParams(fieldInfo, indexPath, knnEngine);
indexBuilder.buildAndWriteIndex(nativeIndexParams, knnVectorValues);
final BuildIndexParams nativeIndexParams = indexParams(fieldInfo, indexPath, knnEngine, knnVectorValues, totalLiveDocs);
indexBuilder.buildAndWriteIndex(nativeIndexParams);
writeFooter(indexPath, engineFileName, state);
}

// The logic for building parameters needs to be cleaned up. There are various cases handled here.
// Currently it falls under two categories - with model and without model. Without model is further divided based on vector data type.
// TODO: Refactor this so it's scalable. Possibly move it out of this class
private BuildIndexParams indexParams(FieldInfo fieldInfo, String indexPath, KNNEngine knnEngine) throws IOException {
private BuildIndexParams indexParams(
FieldInfo fieldInfo,
String indexPath,
KNNEngine knnEngine,
KNNVectorValues<?> vectorValues,
int totalLiveDocs
) throws IOException {
final Map<String, Object> parameters;
VectorDataType vectorDataType;
if (quantizationState != null) {
Expand All @@ -180,6 +186,8 @@ private BuildIndexParams indexParams(FieldInfo fieldInfo, String indexPath, KNNE
.knnEngine(knnEngine)
.indexPath(indexPath)
.quantizationState(quantizationState)
.vectorValues(vectorValues)
.totalLiveDocs(totalLiveDocs)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.util.Map;
Expand All @@ -29,4 +30,6 @@ public class BuildIndexParams {
*/
@Nullable
QuantizationState quantizationState;
KNNVectorValues<?> vectorValues;
int totalLiveDocs;
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ final class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {
*
* @param knnVectorValues the KNNVectorValues instance containing the vectors.
*/
KNNVectorQuantizationTrainingRequest(KNNVectorValues<T> knnVectorValues) {
super((int) knnVectorValues.totalLiveDocs());
KNNVectorQuantizationTrainingRequest(KNNVectorValues<T> knnVectorValues, long liveDocs) {
super((int) liveDocs);
this.knnVectorValues = knnVectorValues;
this.lastIndex = 0;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,15 @@ public static <T, R> QuantizationService<T, R> getInstance() {
* @return The {@link QuantizationState} containing the state of the trained quantizer.
* @throws IOException If an I/O error occurs during the training process.
*/
public QuantizationState train(final QuantizationParams quantizationParams, final KNNVectorValues<T> knnVectorValues)
throws IOException {
public QuantizationState train(
final QuantizationParams quantizationParams,
final KNNVectorValues<T> knnVectorValues,
final long liveDocs
) throws IOException {
Quantizer<T, R> quantizer = QuantizerFactory.getQuantizer(quantizationParams);

// Create the training request from the vector values
KNNVectorQuantizationTrainingRequest<T> trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues);
KNNVectorQuantizationTrainingRequest<T> trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues, liveDocs);

// Train the quantizer and return the quantization state
return quantizer.train(trainingRequest);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,18 @@ public int bytesPerVector() {
}

/**
* Returns the total live docs for KNNVectorValues.
* Returns the total live docs for KNNVectorValues. This function is broken and doesn't always give the accurate
* live docs count when iterators are {@link FloatVectorValues}, {@link ByteVectorValues}. Avoid using this method;
* rather, count the docs with a simple loop like this:
* <pre class="prettyprint">
* int liveDocs = 0;
* while(vectorValues.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
* liveDocs++;
* }
* </pre>
* @return long
*/
@Deprecated
public long totalLiveDocs() {
return vectorValuesIterator.liveDocs();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,12 @@ public void testBuildAndWrite() {
.knnEngine(KNNEngine.NMSLIB)
.vectorDataType(VectorDataType.FLOAT)
.parameters(Map.of("index", "param"))
.vectorValues(knnVectorValues)
.totalLiveDocs((int) knnVectorValues.totalLiveDocs())
.build();

// When
DefaultIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues);
DefaultIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams);

// Then
mockedJNIService.verify(
Expand Down Expand Up @@ -166,10 +168,12 @@ public void testBuildAndWrite_withQuantization() {
.vectorDataType(VectorDataType.FLOAT)
.parameters(Map.of("index", "param"))
.quantizationState(quantizationState)
.vectorValues(knnVectorValues)
.totalLiveDocs((int) knnVectorValues.totalLiveDocs())
.build();

// When
MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues);
MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams);

// Then
mockedJNIService.verify(
Expand Down Expand Up @@ -250,10 +254,12 @@ public void testBuildAndWriteWithModel() {
.knnEngine(KNNEngine.NMSLIB)
.vectorDataType(VectorDataType.FLOAT)
.parameters(Map.of("model_id", "id", "model_blob", modelBlob))
.vectorValues(knnVectorValues)
.totalLiveDocs((int) knnVectorValues.totalLiveDocs())
.build();

// When
DefaultIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues);
DefaultIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams);

// Then
mockedJNIService.verify(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,12 @@ public void testBuildAndWrite() {
.knnEngine(KNNEngine.FAISS)
.vectorDataType(VectorDataType.FLOAT)
.parameters(Map.of("index", "param"))
.vectorValues(knnVectorValues)
.totalLiveDocs((int) knnVectorValues.totalLiveDocs())
.build();

// When
MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues);
MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams);

// Then
mockedJNIService.verify(
Expand Down Expand Up @@ -193,10 +195,12 @@ public void testBuildAndWrite_withQuantization() {
.vectorDataType(VectorDataType.FLOAT)
.parameters(Map.of("index", "param"))
.quantizationState(quantizationState)
.vectorValues(knnVectorValues)
.totalLiveDocs((int) knnVectorValues.totalLiveDocs())
.build();

// When
MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues);
MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams);

// Then
mockedJNIService.verify(
Expand Down
Loading

0 comments on commit da854c9

Please sign in to comment.