
Commit

Makes sure KNNVectorValues aren't recreated unnecessarily when quantization isn't needed (#2133) (#2140)

Signed-off-by: Tejas Shah <[email protected]>
(cherry picked from commit e33afa5)
shatejas authored Sep 27, 2024
1 parent 19e181d commit ca6b03f
Showing 5 changed files with 622 additions and 105 deletions.
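For context, the core of this change is to hand the flush and merge paths a `Supplier<KNNVectorValues<?>>` and only invoke it when a pass over the vectors is actually required, instead of building the vector values up front for every field. Below is a minimal, self-contained sketch of that lazy-supplier pattern; `VectorSource`, `buildVectors`, and `writeField` are illustrative stand-ins, not the k-NN plugin's real types or APIs.

```java
import java.util.List;
import java.util.function.Supplier;

public final class LazyVectorValuesSketch {

    // Stand-in for KNNVectorValues: a one-shot view over a field's vectors.
    record VectorSource(List<float[]> vectors) {
        int totalLiveDocs() { return vectors.size(); }
    }

    // Expensive construction we do not want to repeat unnecessarily.
    static VectorSource buildVectors(List<float[]> raw) {
        System.out.println("materializing vector values");
        return new VectorSource(raw);
    }

    // Before the change: vector values were created eagerly (for counting, training, and
    // writing) even when quantization was not configured. After the change: a Supplier
    // defers creation so each consumer materializes a fresh pass only when it needs one.
    static void writeField(List<float[]> raw, boolean quantizationEnabled) {
        Supplier<VectorSource> supplier = () -> buildVectors(raw);

        if (quantizationEnabled) {
            VectorSource trainingPass = supplier.get(); // extra pass only when training is needed
            System.out.println("training on " + trainingPass.totalLiveDocs() + " vectors");
        }
        VectorSource indexingPass = supplier.get();     // the single pass used to build the index
        System.out.println("indexing " + indexingPass.totalLiveDocs() + " vectors");
    }

    public static void main(String[] args) {
        // With quantization disabled, the vector values are materialized exactly once.
        writeField(List.of(new float[] { 1f, 2f }, new float[] { 3f, 4f }), false);
    }
}
```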
1 change: 1 addition & 0 deletions release-notes/opensearch-knn.release-notes-2.17.0.0.md
@@ -21,6 +21,7 @@ Compatible with OpenSearch 2.17.0
* Fix memory overflow caused by cache behavior [#2015](https://github.com/opensearch-project/k-NN/pull/2015)
* Use correct type for binary vector in ivf training [#2086](https://github.com/opensearch-project/k-NN/pull/2086)
* Switch MINGW32 to MINGW64 [#2090](https://github.com/opensearch-project/k-NN/pull/2090)
* Does not create additional KNNVectorValues in NativeEngines990KNNVectorWriter when quantization is not needed [#2133](https://github.com/opensearch-project/k-NN/pull/2133)
### Infrastructure
* Parallelize make to reduce build time [#2006](https://github.com/opensearch-project/k-NN/pull/2006)
### Maintenance
NativeEngines990KnnVectorsWriter.java
@@ -25,20 +25,21 @@
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.opensearch.knn.plugin.stats.KNNGraphValue;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType;
import static org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory.getVectorValues;

/**
* A KNNVectorsWriter class for writing the vector data structures and flat vectors for Native Engines.
@@ -47,15 +48,11 @@
public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngines990KnnVectorsWriter.class);

private static final String FLUSH_OPERATION = "flush";
private static final String MERGE_OPERATION = "merge";

private final SegmentWriteState segmentWriteState;
private final FlatVectorsWriter flatVectorsWriter;
private KNN990QuantizationStateWriter quantizationStateWriter;
private final List<NativeEngineFieldVectorsWriter<?>> fields = new ArrayList<>();
private boolean finished;
private final QuantizationService quantizationService = QuantizationService.getInstance();

public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) {
this.segmentWriteState = segmentWriteState;
@@ -84,14 +81,27 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
flatVectorsWriter.flush(maxDoc, sortMap);

for (final NativeEngineFieldVectorsWriter<?> field : fields) {
trainAndIndex(
field.getFieldInfo(),
(vectorDataType, fieldInfo, fieldVectorsWriter) -> getKNNVectorValues(vectorDataType, fieldVectorsWriter),
NativeIndexWriter::flushIndex,
field,
KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS,
FLUSH_OPERATION
);
final FieldInfo fieldInfo = field.getFieldInfo();
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
int totalLiveDocs = field.getVectors().size();
if (totalLiveDocs > 0) {
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getVectorValues(
vectorDataType,
field.getDocsWithField(),
field.getVectors()
);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();

StopWatch stopWatch = new StopWatch().start();
writer.flushIndex(knnVectorValues, totalLiveDocs);
long time_in_millis = stopWatch.stop().totalTime().millis();
KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis);
log.debug("Flush took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName());
} else {
log.debug("[Flush] No live docs for field {}", fieldInfo.getName());
}
}
}

@@ -100,15 +110,29 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState
// This will ensure that we are merging the FlatIndex during force merge.
flatVectorsWriter.mergeOneField(fieldInfo, mergeState);

// For merge, pick values from flat vector and reindex again. This will use the flush operation to create graphs
trainAndIndex(
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getKNNVectorValuesForMerge(
vectorDataType,
fieldInfo,
this::getKNNVectorValuesForMerge,
NativeIndexWriter::mergeIndex,
mergeState,
KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS,
MERGE_OPERATION
mergeState
);
int totalLiveDocs = getLiveDocs(knnVectorValuesSupplier.get());
if (totalLiveDocs == 0) {
log.debug("[Merge] No live docs for field {}", fieldInfo.getName());
return;
}

final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);
final NativeIndexWriter writer = NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState);
final KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();

StopWatch stopWatch = new StopWatch().start();

writer.mergeIndex(knnVectorValues, totalLiveDocs);

long time_in_millis = stopWatch.stop().totalTime().millis();
KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.incrementBy(time_in_millis);
log.debug("Merge took {} ms for vector field [{}]", time_in_millis, fieldInfo.getName());
}

/**
@@ -157,18 +181,6 @@ public long ramBytesUsed() {
.sum();
}

/**
* Retrieves the {@link KNNVectorValues} for a specific field based on the vector data type and field writer.
*
* @param vectorDataType The {@link VectorDataType} representing the type of vectors stored.
* @param field The {@link NativeEngineFieldVectorsWriter} representing the field from which to retrieve vectors.
* @param <T> The type of vectors being processed.
* @return The {@link KNNVectorValues} associated with the field.
*/
private <T> KNNVectorValues<T> getKNNVectorValues(final VectorDataType vectorDataType, final NativeEngineFieldVectorsWriter<?> field) {
return (KNNVectorValues<T>) KNNVectorValuesFactory.getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors());
}

/**
* Retrieves the {@link KNNVectorValues} for a specific field during a merge operation, based on the vector data type.
*
@@ -183,89 +195,41 @@ private <T> KNNVectorValues<T> getKNNVectorValuesForMerge(
final VectorDataType vectorDataType,
final FieldInfo fieldInfo,
final MergeState mergeState
) throws IOException {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
return (KNNVectorValues<T>) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats);
case BYTE:
ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
return (KNNVectorValues<T>) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes);
default:
throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
) {
try {
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32:
FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
return getVectorValues(vectorDataType, mergedFloats);
case BYTE:
ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
return getVectorValues(vectorDataType, mergedBytes);
default:
throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
}
} catch (final IOException e) {
log.error("Unable to merge vectors for field [{}]", fieldInfo.getName(), e);
throw new IllegalStateException("Unable to merge vectors for field [" + fieldInfo.getName() + "]", e);
}
}

/**
* Functional interface representing an operation that indexes the provided {@link KNNVectorValues}.
*
* @param <T> The type of vectors being processed.
*/
@FunctionalInterface
private interface IndexOperation<T> {
void buildAndWrite(NativeIndexWriter writer, KNNVectorValues<T> knnVectorValues, int totalLiveDocs) throws IOException;
}

/**
* Functional interface representing a method that retrieves {@link KNNVectorValues} based on
* the vector data type, field information, and the merge state.
*
* @param <DataType> The type of the data representing the vector (e.g., {@link VectorDataType}).
* @param <FieldInfo> The metadata about the field.
* @param <MergeState> The state of the merge operation.
* @param <Result> The result of the retrieval, typically {@link KNNVectorValues}.
*/
@FunctionalInterface
private interface VectorValuesRetriever<DataType, FieldInfo, MergeState, Result> {
Result apply(DataType vectorDataType, FieldInfo fieldInfo, MergeState mergeState) throws IOException;
}

/**
* Unified method for processing a field during either the indexing or merge operation. This method retrieves vector values
* based on the provided vector data type and applies the specified index operation, potentially including quantization if needed.
*
* @param fieldInfo The {@link FieldInfo} object containing metadata about the field.
* @param vectorValuesRetriever A functional interface that retrieves {@link KNNVectorValues} based on the vector data type,
* field information, and additional context (e.g., merge state or field writer).
* @param indexOperation A functional interface that performs the indexing operation using the retrieved
* {@link KNNVectorValues}.
* @param VectorProcessingContext The additional context required for retrieving the vector values (e.g., {@link MergeState} or {@link NativeEngineFieldVectorsWriter}).
* From Flush we need NativeFieldWriter which contains total number of vectors while from Merge we need merge state which contains vector information
* @param <T> The type of vectors being processed.
* @param <C> The type of the context needed for retrieving the vector values.
* @throws IOException If an I/O error occurs during the processing.
*/
private <T, C> void trainAndIndex(
private QuantizationState train(
final FieldInfo fieldInfo,
final VectorValuesRetriever<VectorDataType, FieldInfo, C, KNNVectorValues<T>> vectorValuesRetriever,
final IndexOperation<T> indexOperation,
final C VectorProcessingContext,
final KNNGraphValue graphBuildTime,
final String operationName
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier,
final int totalLiveDocs
) throws IOException {
final VectorDataType vectorDataType = extractVectorDataType(fieldInfo);
KNNVectorValues<T> knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext);
QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);

final QuantizationService quantizationService = QuantizationService.getInstance();
final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
QuantizationState quantizationState = null;
// Count the docIds
int totalLiveDocs = getLiveDocs(vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext));
if (quantizationParams != null && totalLiveDocs > 0) {
initQuantizationStateWriterIfNecessary();
KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs);
quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
}
NativeIndexWriter writer = (quantizationParams != null)
? NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)
: NativeIndexWriter.getWriter(fieldInfo, segmentWriteState);

knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext);

StopWatch stopWatch = new StopWatch();
stopWatch.start();
indexOperation.buildAndWrite(writer, knnVectorValues, totalLiveDocs);
long time_in_millis = stopWatch.totalTime().millis();
graphBuildTime.incrementBy(time_in_millis);
log.warn("Graph build took " + time_in_millis + " ms for " + operationName);
return quantizationState;
}
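One side effect of routing the merge path through a `Supplier` is visible in `getKNNVectorValuesForMerge`: `Supplier.get()` cannot throw checked exceptions, so the method no longer declares `throws IOException` and instead wraps the failure in an `IllegalStateException`. The sketch below illustrates that constraint under illustrative names; it uses `UncheckedIOException` for the wrapper, whereas the commit itself uses `IllegalStateException`.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.function.Supplier;

public final class CheckedExceptionInSupplierSketch {

    // A checked-exception method, analogous to the pre-change getKNNVectorValuesForMerge.
    static String readMergedVectors() throws IOException {
        return "merged vectors";
    }

    // Converts the checked IOException into an unchecked exception, analogous to the
    // try/catch added in this commit, so the method can be called from a lambda.
    static String readMergedVectorsUnchecked() {
        try {
            return readMergedVectors();
        } catch (IOException e) {
            throw new UncheckedIOException("Unable to merge vectors", e);
        }
    }

    public static void main(String[] args) {
        // Supplier.get() declares no checked exceptions, so this would not compile:
        // Supplier<String> broken = () -> readMergedVectors();

        // Wrapping the IOException makes the method usable behind a Supplier.
        Supplier<String> supplier = CheckedExceptionInSupplierSketch::readMergedVectorsUnchecked;
        System.out.println(supplier.get());
    }
}
```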
