From ed4b530bad0bc6cad01bc21278ebf099fa489259 Mon Sep 17 00:00:00 2001 From: VIKASH TIWARI Date: Tue, 20 Aug 2024 21:25:09 -0700 Subject: [PATCH] Integration of Quantization Framework for Binary Quantization with Indexing Flow Signed-off-by: VIKASH TIWARI --- .../NativeEngines990KnnVectorsWriter.java | 133 ++++++++++++--- .../DefaultIndexBuildStrategy.java | 59 +++++-- .../MemOptimizedNativeIndexBuildStrategy.java | 59 +++++-- .../codec/nativeindex/NativeIndexWriter.java | 45 ++++- .../nativeindex/model/BuildIndexParams.java | 8 +- .../KNNVectorQuantizationTrainingRequest.java | 47 ++++++ .../QuantizationService.java | 123 ++++++++++++++ .../MultiBitScalarQuantizationState.java | 42 +++++ .../OneBitScalarQuantizationState.java | 33 ++++ .../quantizationState/QuantizationState.java | 30 ++++ ...NativeEngines990KnnVectorsFormatTests.java | 54 ++++-- .../DefaultIndexBuildStrategyTests.java | 116 +++++++++++++ ...ptimizedNativeIndexBuildStrategyTests.java | 118 +++++++++++++ .../QuantizationServiceTests.java | 159 ++++++++++++++++++ .../factory/QuantizerFactoryTests.java | 48 ++---- .../factory/QuantizerRegistryTests.java | 28 +-- .../QuantizationStateTests.java | 73 ++++++++ 17 files changed, 1068 insertions(+), 107 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java create mode 100644 src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java create mode 100644 src/test/java/org/opensearch/knn/index/quantizationService/QuantizationServiceTests.java diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 65736a63ef..92f9bb8312 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ 
b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -24,10 +24,13 @@ import org.apache.lucene.index.Sorter; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; +import org.opensearch.knn.index.quantizationService.QuantizationService; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; import java.util.ArrayList; @@ -46,6 +49,7 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private final FlatVectorsWriter flatVectorsWriter; private final List> fields = new ArrayList<>(); private boolean finished; + private final QuantizationService quantizationService = QuantizationService.getInstance(); /** * Add new field for indexing. 
@@ -68,17 +72,14 @@ public KnnFieldVectorsWriter addField(final FieldInfo fieldInfo) throws IOExc */ @Override public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { - // simply write data in the flat file flatVectorsWriter.flush(maxDoc, sortMap); for (final NativeEngineFieldVectorsWriter field : fields) { - final VectorDataType vectorDataType = extractVectorDataType(field.getFieldInfo()); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( - vectorDataType, - field.getDocsWithField(), - field.getVectors() + trainAndIndex( + field.getFieldInfo(), + (vectorDataType, fieldInfo, fieldVectorsWriter) -> getKNNVectorValues(vectorDataType, fieldVectorsWriter), + NativeIndexWriter::flushIndex, + field ); - - NativeIndexWriter.getWriter(field.getFieldInfo(), segmentWriteState).flushIndex(knnVectorValues); } } @@ -86,24 +87,9 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState) throws IOException { // This will ensure that we are merging the FlatIndex during force merge. flatVectorsWriter.mergeOneField(fieldInfo, mergeState); - // For merge, pick values from flat vector and reindex again. 
This will use the flush operation to create graphs - final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); - final KNNVectorValues knnVectorValues; - switch (fieldInfo.getVectorEncoding()) { - case FLOAT32: - final FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats); - break; - case BYTE: - final ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); - knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes); - break; - default: - throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); - } + trainAndIndex(fieldInfo, this::getKNNVectorValuesForMerge, NativeIndexWriter::mergeIndex, mergeState); - NativeIndexWriter.getWriter(fieldInfo, segmentWriteState).mergeIndex(knnVectorValues); } /** @@ -146,4 +132,103 @@ public long ramBytesUsed() { .sum(); } + /** + * Retrieves the {@link KNNVectorValues} for a specific field based on the vector data type and field writer. + * + * @param vectorDataType The {@link VectorDataType} representing the type of vectors stored. + * @param field The {@link NativeEngineFieldVectorsWriter} representing the field from which to retrieve vectors. + * @param The type of vectors being processed. + * @return The {@link KNNVectorValues} associated with the field. + */ + private KNNVectorValues getKNNVectorValues(final VectorDataType vectorDataType, final NativeEngineFieldVectorsWriter field) { + return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors()); + } + + /** + * Retrieves the {@link KNNVectorValues} for a specific field during a merge operation, based on the vector data type. + * + * @param vectorDataType The {@link VectorDataType} representing the type of vectors stored. 
+ * @param fieldInfo The {@link FieldInfo} object containing metadata about the field. + * @param mergeState The {@link MergeState} representing the state of the merge operation. + * @param The type of vectors being processed. + * @return The {@link KNNVectorValues} associated with the field during the merge. + * @throws IOException If an I/O error occurs during the retrieval. + */ + private KNNVectorValues getKNNVectorValuesForMerge( + final VectorDataType vectorDataType, + final FieldInfo fieldInfo, + final MergeState mergeState + ) throws IOException { + switch (fieldInfo.getVectorEncoding()) { + case FLOAT32: + FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats); + case BYTE: + ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); + return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes); + default: + throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); + } + } + + /** + * Functional interface representing an operation that indexes the provided {@link KNNVectorValues}. + * + * @param The type of vectors being processed. + */ + @FunctionalInterface + private interface IndexOperation { + void buildAndWrite(NativeIndexWriter writer, KNNVectorValues knnVectorValues) throws IOException; + } + + /** + * Functional interface representing a method that retrieves {@link KNNVectorValues} based on + * the vector data type, field information, and the merge state. + * + * @param The type of the data representing the vector (e.g., {@link VectorDataType}). + * @param The metadata about the field. + * @param The state of the merge operation. + * @param The result of the retrieval, typically {@link KNNVectorValues}. 
+ */ + @FunctionalInterface + private interface VectorValuesRetriever { + Result apply(DataType vectorDataType, FieldInfo fieldInfo, MergeState mergeState) throws IOException; + } + + /** + * Unified method for processing a field during either the indexing or merge operation. This method retrieves vector values + * based on the provided vector data type and applies the specified index operation, potentially including quantization if needed. + * + * @param fieldInfo The {@link FieldInfo} object containing metadata about the field. + * @param vectorValuesRetriever A functional interface that retrieves {@link KNNVectorValues} based on the vector data type, + * field information, and additional context (e.g., merge state or field writer). + * @param indexOperation A functional interface that performs the indexing operation using the retrieved + * {@link KNNVectorValues}. + * @param context The additional context required for retrieving the vector values (e.g., {@link MergeState} or {@link NativeEngineFieldVectorsWriter}). + * @param The type of vectors being processed. + * @param The type of the context needed for retrieving the vector values. + * @throws IOException If an I/O error occurs during the processing. 
+ */ + private void trainAndIndex( + final FieldInfo fieldInfo, + final VectorValuesRetriever> vectorValuesRetriever, + final IndexOperation indexOperation, + final C context + ) throws IOException { + final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); + KNNVectorValues knnVectorValuesForTraining = vectorValuesRetriever.apply(vectorDataType, fieldInfo, context); + KNNVectorValues knnVectorValuesForIndexing = vectorValuesRetriever.apply(vectorDataType, fieldInfo, context); + + QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); + QuantizationState quantizationState = null; + + if (quantizationParams != null) { + quantizationState = quantizationService.train(quantizationParams, knnVectorValuesForTraining); + } + NativeIndexWriter writer = (quantizationParams != null) + ? NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState) + : NativeIndexWriter.getWriter(fieldInfo, segmentWriteState); + + indexOperation.buildAndWrite(writer, knnVectorValuesForIndexing); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java index 5787ea76bf..01ebce17b1 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java @@ -11,8 +11,11 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; +import org.opensearch.knn.index.quantizationService.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import 
org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; import java.security.AccessController; @@ -39,16 +42,50 @@ public static DefaultIndexBuildStrategy getInstance() { return INSTANCE; } + /** + * Builds and writes a k-NN index using the provided vector values and index parameters. This method handles both + * quantized and non-quantized vectors, transferring them off-heap before building the index using native JNI services. + * + *

The method first iterates over the vector values to calculate the necessary bytes per vector. If quantization is + * enabled, the vectors are quantized before being transferred off-heap. Once all vectors are transferred, they are + * flushed and used to build the index. The index is then written to the specified path using JNI calls.

+ * + * @param indexInfo The {@link BuildIndexParams} containing the parameters and configuration for building the index. + * @param knnVectorValues The {@link KNNVectorValues} representing the vectors to be indexed. + * @throws IOException If an I/O error occurs during the process of building and writing the index. + */ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) throws IOException { - iterateVectorValuesOnce(knnVectorValues); // to get bytesPerVector - int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / knnVectorValues.bytesPerVector()); + // Needed to make sure we don't get 0 dimensions while initializing index + iterateVectorValuesOnce(knnVectorValues); + QuantizationService quantizationHandler = QuantizationService.getInstance(); + QuantizationState quantizationState = indexInfo.getQuantizationState(); + QuantizationOutput quantizationOutput = null; + + int bytesPerVector; + int dimensions; + + // Handle quantization state if present + if (quantizationState != null) { + bytesPerVector = quantizationState.getBytesPerVector(); + dimensions = quantizationState.getDimensions(); + quantizationOutput = quantizationHandler.createQuantizationOutput(quantizationState.getQuantizationParams()); + } else { + bytesPerVector = knnVectorValues.bytesPerVector(); + dimensions = knnVectorValues.dimension(); + } + + int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector); try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) { + final List transferredDocIds = new ArrayList<>((int) knnVectorValues.totalLiveDocs()); - final List tranferredDocIds = new ArrayList<>(); while (knnVectorValues.docId() != NO_MORE_DOCS) { + Object vector = knnVectorValues.conditionalCloneVector(); + if (quantizationState != null && quantizationOutput != null) { + vector = 
quantizationHandler.quantize(quantizationState, vector, quantizationOutput); + } // append is true here so off heap memory buffer isn't overwritten - vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), true); - tranferredDocIds.add(knnVectorValues.docId()); + vectorTransfer.transfer(vector, true); + transferredDocIds.add(knnVectorValues.docId()); knnVectorValues.nextDoc(); } vectorTransfer.flush(true); @@ -60,12 +97,12 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector if (params.containsKey(MODEL_ID)) { AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.createIndexFromTemplate( - intListToArray(tranferredDocIds), + intListToArray(transferredDocIds), vectorAddress, - knnVectorValues.dimension(), + dimensions, indexInfo.getIndexPath(), (byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER), - indexInfo.getParameters(), + params, indexInfo.getKnnEngine() ); return null; @@ -73,11 +110,11 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector } else { AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.createIndex( - intListToArray(tranferredDocIds), + intListToArray(transferredDocIds), vectorAddress, - knnVectorValues.dimension(), + dimensions, indexInfo.getIndexPath(), - indexInfo.getParameters(), + params, indexInfo.getKnnEngine() ); return null; diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java index af80215b65..8aa75055d2 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java @@ -11,8 +11,11 @@ import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; 
import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.quantizationService.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; import java.security.AccessController; @@ -40,11 +43,39 @@ public static MemOptimizedNativeIndexBuildStrategy getInstance() { return INSTANCE; } + /** + * Builds and writes a k-NN index using the provided vector values and index parameters. This method handles both + * quantized and non-quantized vectors, transferring them off-heap before building the index using native JNI services. + * + *

The method first iterates over the vector values to calculate the necessary bytes per vector. If quantization is + * enabled, the vectors are quantized before being transferred off-heap. Once all vectors are transferred, they are + * flushed and used to build the index. The index is then written to the specified path using JNI calls.

+ * + * @param indexInfo The {@link BuildIndexParams} containing the parameters and configuration for building the index. + * @param knnVectorValues The {@link KNNVectorValues} representing the vectors to be indexed. + * @throws IOException If an I/O error occurs during the process of building and writing the index. + */ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) throws IOException { - // Needed to make sure we dont get 0 dimensions while initializing index + // Needed to make sure we don't get 0 dimensions while initializing index iterateVectorValuesOnce(knnVectorValues); KNNEngine engine = indexInfo.getKnnEngine(); Map indexParameters = indexInfo.getParameters(); + QuantizationService quantizationHandler = QuantizationService.getInstance(); + QuantizationState quantizationState = indexInfo.getQuantizationState(); + QuantizationOutput quantizationOutput = null; + + int bytesPerVector; + int dimensions; + + // Handle quantization state if present + if (quantizationState != null) { + bytesPerVector = quantizationState.getBytesPerVector(); + dimensions = quantizationState.getDimensions(); + quantizationOutput = quantizationHandler.createQuantizationOutput(quantizationState.getQuantizationParams()); + } else { + bytesPerVector = knnVectorValues.bytesPerVector(); + dimensions = knnVectorValues.dimension(); + } // Initialize the index long indexMemoryAddress = AccessController.doPrivileged( @@ -56,29 +87,35 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector ) ); - int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / knnVectorValues.bytesPerVector()); + int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector); try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) { - final List tranferredDocIds = new 
ArrayList<>(transferLimit); + final List transferredDocIds = new ArrayList<>(transferLimit); + while (knnVectorValues.docId() != NO_MORE_DOCS) { // append is false to be able to reuse the memory location - boolean transferred = vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), false); - tranferredDocIds.add(knnVectorValues.docId()); + Object vector = knnVectorValues.conditionalCloneVector(); + if (quantizationState != null && quantizationOutput != null) { + vector = quantizationHandler.quantize(quantizationState, vector, quantizationOutput); + } + // append is false to be able to reuse the memory location + boolean transferred = vectorTransfer.transfer(vector, false); + transferredDocIds.add(knnVectorValues.docId()); if (transferred) { // Insert vectors long vectorAddress = vectorTransfer.getVectorAddress(); AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.insertToIndex( - intListToArray(tranferredDocIds), + intListToArray(transferredDocIds), vectorAddress, - knnVectorValues.dimension(), + dimensions, indexParameters, indexMemoryAddress, engine ); return null; }); - tranferredDocIds.clear(); + transferredDocIds.clear(); } knnVectorValues.nextDoc(); } @@ -89,16 +126,16 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector long vectorAddress = vectorTransfer.getVectorAddress(); AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.insertToIndex( - intListToArray(tranferredDocIds), + intListToArray(transferredDocIds), vectorAddress, - knnVectorValues.dimension(), + dimensions, indexParameters, indexMemoryAddress, engine ); return null; }); - tranferredDocIds.clear(); + transferredDocIds.clear(); } // Write vector diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java index 61500371b1..4d69485ad5 100644 --- 
a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -12,6 +12,7 @@ import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.store.FilterDirectory; +import org.opensearch.common.Nullable; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.xcontent.DeprecationHandler; @@ -23,11 +24,13 @@ import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.quantizationService.QuantizationService; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; import java.io.OutputStream; @@ -60,6 +63,8 @@ public class NativeIndexWriter { private final SegmentWriteState state; private final FieldInfo fieldInfo; private final NativeIndexBuildStrategy indexBuilder; + @Nullable + private final QuantizationState quantizationState; /** * Gets the correct writer type from fieldInfo @@ -72,9 +77,37 @@ public static NativeIndexWriter getWriter(final FieldInfo fieldInfo, SegmentWrit boolean isTemplate = fieldInfo.attributes().containsKey(MODEL_ID); boolean iterative = !isTemplate && KNNEngine.FAISS == knnEngine; if (iterative) { - return new NativeIndexWriter(state, fieldInfo, MemOptimizedNativeIndexBuildStrategy.getInstance()); + return new NativeIndexWriter(state, fieldInfo, MemOptimizedNativeIndexBuildStrategy.getInstance(), null); } - return new NativeIndexWriter(state, fieldInfo, 
DefaultIndexBuildStrategy.getInstance()); + return new NativeIndexWriter(state, fieldInfo, DefaultIndexBuildStrategy.getInstance(), null); + } + + /** + * Gets the correct writer type for the specified field, using a given QuantizationState. + * + * This method returns a NativeIndexWriter instance that is tailored to the specific characteristics + * of the field described by the provided FieldInfo. It determines whether to use a template-based + * writer or an iterative approach based on the engine type and whether the field is associated with a template. + * + * If quantization is required, the QuantizationState is passed to the writer to facilitate the quantization process. + * + * @param fieldInfo The FieldInfo object containing metadata about the field for which the writer is needed. + * @param state The SegmentWriteState representing the current segment's writing context. + * @param quantizationState The QuantizationState containing the information required for quantization + * @return A NativeIndexWriter instance appropriate for the specified field, configured with or without quantization. + */ + public static NativeIndexWriter getWriter( + final FieldInfo fieldInfo, + final SegmentWriteState state, + final QuantizationState quantizationState + ) { + final KNNEngine knnEngine = extractKNNEngine(fieldInfo); + boolean isTemplate = fieldInfo.attributes().containsKey(MODEL_ID); + boolean iterative = !isTemplate && KNNEngine.FAISS == knnEngine; + if (iterative) { + return new NativeIndexWriter(state, fieldInfo, MemOptimizedNativeIndexBuildStrategy.getInstance(), quantizationState); + } + return new NativeIndexWriter(state, fieldInfo, DefaultIndexBuildStrategy.getInstance(), quantizationState); + } /** @@ -137,7 +170,12 @@ private void buildAndWriteIndex(final KNNVectorValues knnVectorValues) throws // TODO: Refactor this so its scalable. 
Possibly move it out of this class private BuildIndexParams indexParams(FieldInfo fieldInfo, String indexPath, KNNEngine knnEngine) throws IOException { final Map parameters; - final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); + VectorDataType vectorDataType; + if (quantizationState != null) { + vectorDataType = QuantizationService.getInstance().getVectorDataTypeForTransfer(fieldInfo); + } else { + vectorDataType = extractVectorDataType(fieldInfo); + } if (fieldInfo.attributes().containsKey(MODEL_ID)) { Model model = getModel(fieldInfo); parameters = getTemplateParameters(fieldInfo, model); @@ -151,6 +189,7 @@ private BuildIndexParams indexParams(FieldInfo fieldInfo, String indexPath, KNNE .vectorDataType(vectorDataType) .knnEngine(knnEngine) .indexPath(indexPath) + .quantizationState(quantizationState) .build(); } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java index af43ff37e8..78674c64bf 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java @@ -8,8 +8,10 @@ import lombok.Builder; import lombok.ToString; import lombok.Value; +import org.opensearch.common.Nullable; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.util.Map; @@ -22,5 +24,9 @@ public class BuildIndexParams { String indexPath; VectorDataType vectorDataType; Map parameters; - // TODO: Add quantization state as parameter to build index + /** + * An optional quantization state that contains required information for quantization + */ + @Nullable + QuantizationState quantizationState; } diff --git 
a/src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java b/src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java new file mode 100644 index 0000000000..f885f6c4d8 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.quantizationService; + +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; + +/** + * KNNVectorQuantizationTrainingRequest is a concrete implementation of the abstract TrainingRequest class. + * It provides a mechanism to retrieve float vectors from the KNNVectorValues by document ID. + */ +class KNNVectorQuantizationTrainingRequest extends TrainingRequest { + + private final KNNVectorValues knnVectorValues; + + /** + * Constructs a new KNNVectorQuantizationTrainingRequest. + * + * @param knnVectorValues the KNNVectorValues instance containing the vectors. + */ + KNNVectorQuantizationTrainingRequest(KNNVectorValues knnVectorValues) { + super((int) knnVectorValues.totalLiveDocs()); + this.knnVectorValues = knnVectorValues; + } + + /** + * Retrieves the float vector associated with the specified document ID. + * + * @param docId the document ID. + * @return the vector corresponding to the specified document ID; a RuntimeException is thrown if the docId is invalid. 
+ */ + @Override + public T getVectorByDocId(int docId) { + try { + if (knnVectorValues.advance(docId) == docId) { + return knnVectorValues.getVector(); + } else { + throw new RuntimeException("Failed to advance to the correct docId: " + docId); + } + } catch (Exception e) { + throw new RuntimeException("Failed to retrieve vector for docId: " + docId, e); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java b/src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java new file mode 100644 index 0000000000..f971f70f96 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java @@ -0,0 +1,123 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.quantizationService; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.apache.lucene.index.FieldInfo; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.factory.QuantizerFactory; +import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.quantizer.Quantizer; +import java.io.IOException; + +/** + * A singleton class responsible for handling the quantization process, including training a quantizer + * and applying quantization to vectors. This class is designed to be thread-safe. 
+ * + * @param The type of the input vectors to be quantized. + * @param The type of the quantized output vectors. + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class QuantizationService { + + /** + * The singleton instance of the {@link QuantizationService} class. + */ + private static final QuantizationService INSTANCE = new QuantizationService<>(); + + /** + * Returns the singleton instance of the {@link QuantizationService} class. + * + * @param The type of the input vectors to be quantized. + * @param The type of the quantized output vectors. + * @return The singleton instance of {@link QuantizationService}. + */ + public static QuantizationService getInstance() { + return (QuantizationService) INSTANCE; + } + + /** + * Trains a quantizer using the provided {@link KNNVectorValues} and returns the resulting + * {@link QuantizationState}. The quantizer is determined based on the given {@link QuantizationParams}. + * + * @param quantizationParams The {@link QuantizationParams} containing the parameters for quantization. + * @param knnVectorValues The {@link KNNVectorValues} representing the vector data to be used for training. + * @return The {@link QuantizationState} containing the state of the trained quantizer. + * @throws IOException If an I/O error occurs during the training process. + */ + public QuantizationState train(final QuantizationParams quantizationParams, final KNNVectorValues knnVectorValues) + throws IOException { + Quantizer quantizer = QuantizerFactory.getQuantizer(quantizationParams); + + // Create the training request from the vector values + KNNVectorQuantizationTrainingRequest trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues); + + // Train the quantizer and return the quantization state + return quantizer.train(trainingRequest); + } + + /** + * Applies quantization to the given vector using the specified {@link QuantizationState} and + * {@link QuantizationOutput}. 
+ * + * @param quantizationState The {@link QuantizationState} containing the state of the trained quantizer. + * @param vector The vector to be quantized. + * @param quantizationOutput The {@link QuantizationOutput} to store the quantized vector. + * @return The quantized vector as an object of type {@code R}. + */ + public R quantize(final QuantizationState quantizationState, final T vector, final QuantizationOutput quantizationOutput) { + Quantizer quantizer = QuantizerFactory.getQuantizer(quantizationState.getQuantizationParams()); + quantizer.quantize(vector, quantizationState, quantizationOutput); + return quantizationOutput.getQuantizedVector(); + } + + /** + * Retrieves quantization parameters from the FieldInfo. + */ + public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) { + // TODO: Replace this with actual logic to extract quantization parameters from FieldInfo + if (fieldInfo.getAttribute("QuantizationConfig") != null) { + return new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + } + return null; + } + + /** + * Retrieves the appropriate {@link VectorDataType} to be used during the transfer of vectors for indexing or merging. + * This method is intended to determine the correct vector data type based on the provided {@link FieldInfo}. + * + * @param fieldInfo The {@link FieldInfo} object containing metadata about the field for which the vector data type + * is being determined. + * @return The {@link VectorDataType} to be used during the vector transfer process + */ + public VectorDataType getVectorDataTypeForTransfer(final FieldInfo fieldInfo) { + // TODO: Replace this with actual logic to extract quantization parameters from FieldInfo + return VectorDataType.BINARY; + } + + /** + * Creates the appropriate {@link QuantizationOutput} based on the given {@link QuantizationParams}. + * + * @param quantizationParams The {@link QuantizationParams} containing the parameters for quantization. 
+ * @return The {@link QuantizationOutput} corresponding to the provided parameters. + * @throws IllegalArgumentException If the quantization parameters are unsupported. + */ + @SuppressWarnings("unchecked") + public QuantizationOutput createQuantizationOutput(final QuantizationParams quantizationParams) { + if (quantizationParams instanceof ScalarQuantizationParams) { + ScalarQuantizationParams scalarParams = (ScalarQuantizationParams) quantizationParams; + return (QuantizationOutput) new BinaryQuantizationOutput(scalarParams.getSqType().getId()); + } + throw new IllegalArgumentException("Unsupported quantization parameters: " + quantizationParams.getClass().getName()); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java index 2778a6cf49..ba54b60ab7 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java @@ -8,6 +8,7 @@ import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; +import org.apache.lucene.util.RamUsageEstimator; import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -124,4 +125,45 @@ public byte[] toByteArray() throws IOException { public static MultiBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException { return (MultiBitScalarQuantizationState) QuantizationStateSerializer.deserialize(bytes, MultiBitScalarQuantizationState::new); } + + /** + * Calculates and returns the number of bytes stored per vector after quantization. + * + * @return the number of bytes stored per vector. 
+ */ + @Override + public int getBytesPerVector() { + // Check if thresholds are null or have invalid structure + if (thresholds == null || thresholds.length == 0 || thresholds[0] == null) { + throw new IllegalStateException("Error in getBytesStoredPerVector: The thresholds array is not initialized."); + } + + // Calculate the number of bytes required for multi-bit quantization + return thresholds.length * thresholds[0].length; + } + + @Override + public int getDimensions() { + // For multi-bit quantization, the dimension for indexing is the number of rows * columns in the thresholds array. + // Where the number of columns represents the dimension of the original vector and the number of rows equals the number of bits + return thresholds.length * thresholds[0].length; + } + + /** + * Calculates the memory usage of the MultiBitScalarQuantizationState object in bytes. + * This method computes the shallow size of the instance itself, the shallow size of the + * quantization parameters, and the memory usage of the 2D thresholds array. + * + * @return The estimated memory usage of the MultiBitScalarQuantizationState object in bytes. 
+ */ + @Override + public long ramBytesUsed() { + long size = RamUsageEstimator.shallowSizeOfInstance(MultiBitScalarQuantizationState.class); + size += RamUsageEstimator.shallowSizeOf(quantizationParams); + size += RamUsageEstimator.shallowSizeOf(thresholds); // shallow size of the 2D array (array of references to rows) + for (float[] row : thresholds) { + size += RamUsageEstimator.sizeOf(row); // size of each row in the 2D array + } + return size; + } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java index 9998b87e8c..9c4ff7460f 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java @@ -8,6 +8,7 @@ import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; +import org.apache.lucene.util.RamUsageEstimator; import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -107,4 +108,36 @@ public byte[] toByteArray() throws IOException { public static OneBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException { return (OneBitScalarQuantizationState) QuantizationStateSerializer.deserialize(bytes, OneBitScalarQuantizationState::new); } + + /** + * Calculates and returns the number of bytes stored per vector after quantization. + * + * @return the number of bytes stored per vector. + */ + @Override + public int getBytesPerVector() { + // Calculate the number of bytes required for one-bit quantization + return meanThresholds.length; + } + + @Override + public int getDimensions() { + // For one-bit quantization, the dimension for indexing is just the length of the thresholds array. 
+ return meanThresholds.length; + } + + /** + * Calculates the memory usage of the OneBitScalarQuantizationState object in bytes. + * This method computes the shallow size of the instance itself, the shallow size of the + * quantization parameters, and the memory usage of the mean thresholds array. + * + * @return The estimated memory usage of the OneBitScalarQuantizationState object in bytes. + */ + @Override + public long ramBytesUsed() { + long size = RamUsageEstimator.shallowSizeOfInstance(OneBitScalarQuantizationState.class); + size += RamUsageEstimator.shallowSizeOf(quantizationParams); + size += RamUsageEstimator.sizeOf(meanThresholds); + return size; + } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java index e32df8bc36..0a68d270e6 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java @@ -29,4 +29,34 @@ public interface QuantizationState extends Writeable { * @throws IOException if an I/O error occurs during serialization. */ byte[] toByteArray() throws IOException; + + /** + * Calculates the number of bytes stored per vector after quantization. + * This method can be overridden by implementing classes to provide the specific calculation. + * + * @return the number of bytes stored per vector. Default is 0. + */ + default int getBytesPerVector() { + return 0; + } + + /** + * Returns the effective dimension used for indexing after quantization. + * For one-bit quantization, this might correspond to the length of thresholds. + * For multi-bit quantization, this might correspond to rows * columns of the thresholds matrix. + * + * @return the effective dimension for indexing. Default is 0. 
+ */ + default int getDimensions() { + return 0; + } + + /** + * Estimates the memory usage of the quantization state in bytes. + * + * @return the memory usage in bytes. + */ + default long ramBytesUsed() { + return 0; + }; } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java index 3810d46fd2..a3fe527187 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java @@ -61,6 +61,7 @@ public class NativeEngines990KnnVectorsFormatTests extends KNNTestCase { private static final String FLAT_VECTOR_FILE_EXT = ".vec"; private static final String HNSW_FILE_EXT = ".hnsw"; private static final String FLOAT_VECTOR_FIELD = "float_field"; + private static final String FLOAT_VECTOR_FIELD_BINARY = "float_field_binary"; private static final String BYTE_VECTOR_FIELD = "byte_field"; private Directory dir; private RandomIndexWriter indexWriter; @@ -99,14 +100,14 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc float[] floatVector = { 1.0f, 3.0f, 4.0f }; byte[] byteVector = { 6, 14 }; - addFieldToIndex( - new KnnFloatVectorField(FLOAT_VECTOR_FIELD, floatVector, createVectorField(3, VectorEncoding.FLOAT32, VectorDataType.FLOAT)), - indexWriter - ); - addFieldToIndex( - new KnnByteVectorField(BYTE_VECTOR_FIELD, byteVector, createVectorField(2, VectorEncoding.BYTE, VectorDataType.BINARY)), - indexWriter - ); + FieldType fieldTypeForFloat = createVectorField(3, VectorEncoding.FLOAT32, VectorDataType.FLOAT); + fieldTypeForFloat.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"HNSW16,Flat\", \"spaceType\": \"l2\"}"); + fieldTypeForFloat.freeze(); + addFieldToIndex(new KnnFloatVectorField(FLOAT_VECTOR_FIELD, floatVector, 
fieldTypeForFloat), indexWriter); + FieldType fieldTypeForByte = createVectorField(2, VectorEncoding.BYTE, VectorDataType.BINARY); + fieldTypeForByte.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"HNSW16,Flat\", \"spaceType\": \"l2\"}"); + fieldTypeForByte.freeze(); + addFieldToIndex(new KnnByteVectorField(BYTE_VECTOR_FIELD, byteVector, fieldTypeForByte), indexWriter); final IndexReader indexReader = indexWriter.getReader(); // ensuring segments are created indexWriter.flush(); @@ -157,6 +158,41 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc indexReader.close(); } + @SneakyThrows + public void testNativeEngineVectorFormat_whenBinaryQuantizationApplied_thenSuccess() { + setup(); + float[] floatVectorForBinaryQuantization = { 1.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f }; + FieldType fieldTypeForBinaryQuantization = createVectorField(8, VectorEncoding.FLOAT32, VectorDataType.FLOAT); + fieldTypeForBinaryQuantization.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"BHNSW32\", \"spaceType\": \"l2\"}"); + fieldTypeForBinaryQuantization.putAttribute("QuantizationConfig", "{ \"type\": \"Binary\" }"); + + addFieldToIndex( + new KnnFloatVectorField(FLOAT_VECTOR_FIELD_BINARY, floatVectorForBinaryQuantization, fieldTypeForBinaryQuantization), + indexWriter + ); + + final IndexReader indexReader = indexWriter.getReader(); + // ensuring segments are created + indexWriter.flush(); + indexWriter.commit(); + indexWriter.close(); + IndexSearcher searcher = new IndexSearcher(indexReader); + final LeafReader leafReader = searcher.getLeafContexts().get(0).reader(); + SegmentReader segmentReader = Lucene.segmentReader(leafReader); + if (segmentReader.getSegmentInfo().info.getUseCompoundFile() == false) { + final List vecfiles = getFilesFromSegment(dir, FLAT_VECTOR_FILE_EXT); + // 2 .vec files will be created as we are using per field vectors format. 
+ assertEquals(1, vecfiles.size()); + } + + final FloatVectorValues floatVectorValues = leafReader.getFloatVectorValues(FLOAT_VECTOR_FIELD_BINARY); + floatVectorValues.nextDoc(); + assertArrayEquals(floatVectorForBinaryQuantization, floatVectorValues.vectorValue(), 0.0f); + assertEquals(1, floatVectorValues.size()); + assertEquals(8, floatVectorValues.dimension()); + indexReader.close(); + } + private List getFilesFromSegment(Directory dir, String fileFormat) throws IOException { return Arrays.stream(dir.listAll()).filter(x -> x.contains(fileFormat)).collect(Collectors.toList()); } @@ -203,13 +239,11 @@ private FieldType createVectorField(int dimension, VectorEncoding vectorEncoding nativeVectorField.putAttribute(KNNConstants.HNSW_ALGO_M, "32"); nativeVectorField.putAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512"); nativeVectorField.putAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); - nativeVectorField.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"HNSW16,Flat\", \"spaceType\": \"l2\"}"); nativeVectorField.setVectorAttributes( dimension, vectorEncoding, SpaceType.L2.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() ); - nativeVectorField.freeze(); return nativeVectorField; } } diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java index 34a333471e..0b5a06dfc8 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java @@ -17,11 +17,14 @@ import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.quantizationService.QuantizationService; import 
org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.index.vectorvalues.TestVectorValues; import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.test.OpenSearchTestCase; import java.util.List; @@ -102,6 +105,119 @@ public void testBuildAndWrite() { } } + @SneakyThrows + public void testBuildAndWrite_withQuantization() { + // Given + ArgumentCaptor vectorAddressCaptor = ArgumentCaptor.forClass(Long.class); + ArgumentCaptor vectorTransferCapture = ArgumentCaptor.forClass(Object.class); + + List vectorValues = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }, new float[] { 3, 4 }); + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + vectorValues + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + try ( + MockedStatic mockedKNNSettings = mockStatic(KNNSettings.class); + MockedStatic mockedJNIService = mockStatic(JNIService.class); + MockedStatic mockedOffHeapVectorTransferFactory = mockStatic(OffHeapVectorTransferFactory.class); + MockedStatic mockedQuantizationIntegration = mockStatic(QuantizationService.class) + ) { + + // Limits transfer to 2 vectors + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); + mockedJNIService.when(() -> JNIService.initIndex(3, 2, Map.of("index", "param"), KNNEngine.FAISS)).thenReturn(100L); + + OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + 
.thenReturn(offHeapVectorTransfer); + + QuantizationService quantizationService = mock(QuantizationService.class); + mockedQuantizationIntegration.when(QuantizationService::getInstance).thenReturn(quantizationService); + + QuantizationState quantizationState = mock(QuantizationState.class); + ArgumentCaptor vectorCaptor = ArgumentCaptor.forClass(float[].class); + // New: Create QuantizationOutput and mock the quantization process + QuantizationOutput quantizationOutput = mock(QuantizationOutput.class); + when(quantizationOutput.getQuantizedVector()).thenReturn(new byte[] { 1, 2 }); + when(quantizationService.createQuantizationOutput(eq(quantizationState.getQuantizationParams()))).thenReturn( + quantizationOutput + ); + + // Quantize the vector with the quantization output + when(quantizationService.quantize(eq(quantizationState), vectorCaptor.capture(), eq(quantizationOutput))).thenAnswer( + invocation -> { + quantizationOutput.getQuantizedVector(); + return quantizationOutput.getQuantizedVector(); + } + ); + when(quantizationState.getDimensions()).thenReturn(2); + when(quantizationState.getBytesPerVector()).thenReturn(8); + + when(offHeapVectorTransfer.transfer(vectorTransferCapture.capture(), eq(false))).thenReturn(false) + .thenReturn(true) + .thenReturn(false); + when(offHeapVectorTransfer.flush(false)).thenReturn(true); + when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); + + BuildIndexParams buildIndexParams = BuildIndexParams.builder() + .indexPath("indexPath") + .knnEngine(KNNEngine.FAISS) + .vectorDataType(VectorDataType.FLOAT) + .parameters(Map.of("index", "param")) + .quantizationState(quantizationState) + .build(); + + // When + MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues); + + // Then + mockedJNIService.verify( + () -> JNIService.initIndex( + knnVectorValues.totalLiveDocs(), + knnVectorValues.dimension(), + Map.of("index", "param"), + KNNEngine.FAISS + ) + ); + + 
mockedJNIService.verify( + () -> JNIService.insertToIndex( + eq(new int[] { 0, 1 }), + vectorAddressCaptor.capture(), + eq(knnVectorValues.dimension()), + eq(Map.of("index", "param")), + eq(100L), + eq(KNNEngine.FAISS) + ) + ); + + // For the flush + mockedJNIService.verify( + () -> JNIService.insertToIndex( + eq(new int[] { 2 }), + vectorAddressCaptor.capture(), + eq(knnVectorValues.dimension()), + eq(Map.of("index", "param")), + eq(100L), + eq(KNNEngine.FAISS) + ) + ); + + mockedJNIService.verify( + () -> JNIService.writeIndex(eq("indexPath"), eq(100L), eq(KNNEngine.FAISS), eq(Map.of("index", "param"))) + ); + assertEquals(200L, vectorAddressCaptor.getValue().longValue()); + assertEquals(vectorAddressCaptor.getValue().longValue(), vectorAddressCaptor.getAllValues().get(0).longValue()); + verify(offHeapVectorTransfer, times(0)).reset(); + + for (Object vector : vectorTransferCapture.getAllValues()) { + // Assert that the vector is in byte[] format due to quantization + assertTrue(vector instanceof byte[]); + } + } + } + @SneakyThrows public void testBuildAndWriteWithModel() { // Given diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java index 2ecfe92596..3bfec4104e 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java @@ -16,10 +16,13 @@ import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.quantizationService.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import 
org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.index.vectorvalues.TestVectorValues; import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.test.OpenSearchTestCase; import java.util.List; @@ -126,4 +129,119 @@ public void testBuildAndWrite() { } } } + + @SneakyThrows + public void testBuildAndWrite_withQuantization() { + // Given + ArgumentCaptor vectorAddressCaptor = ArgumentCaptor.forClass(Long.class); + ArgumentCaptor vectorTransferCapture = ArgumentCaptor.forClass(Object.class); + + List vectorValues = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }, new float[] { 3, 4 }); + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + vectorValues + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + try ( + MockedStatic mockedKNNSettings = Mockito.mockStatic(KNNSettings.class); + MockedStatic mockedJNIService = Mockito.mockStatic(JNIService.class); + MockedStatic mockedOffHeapVectorTransferFactory = Mockito.mockStatic( + OffHeapVectorTransferFactory.class + ); + MockedStatic mockedQuantizationIntegration = Mockito.mockStatic(QuantizationService.class) + ) { + + // Limits transfer to 2 vectors + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); + mockedJNIService.when(() -> JNIService.initIndex(3, 2, Map.of("index", "param"), KNNEngine.FAISS)).thenReturn(100L); + + OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + .thenReturn(offHeapVectorTransfer); + + QuantizationService quantizationService = 
mock(QuantizationService.class); + mockedQuantizationIntegration.when(QuantizationService::getInstance).thenReturn(quantizationService); + + QuantizationState quantizationState = mock(QuantizationState.class); + ArgumentCaptor vectorCaptor = ArgumentCaptor.forClass(float[].class); + // New: Create QuantizationOutput and mock the quantization process + QuantizationOutput quantizationOutput = mock(QuantizationOutput.class); + when(quantizationOutput.getQuantizedVector()).thenReturn(new byte[] { 1, 2 }); + when(quantizationService.createQuantizationOutput(eq(quantizationState.getQuantizationParams()))).thenReturn( + quantizationOutput + ); + + // Quantize the vector with the quantization output + when(quantizationService.quantize(eq(quantizationState), vectorCaptor.capture(), eq(quantizationOutput))).thenAnswer( + invocation -> { + quantizationOutput.getQuantizedVector(); + return quantizationOutput.getQuantizedVector(); + } + ); + when(quantizationState.getDimensions()).thenReturn(2); + when(quantizationState.getBytesPerVector()).thenReturn(8); + + when(offHeapVectorTransfer.transfer(vectorTransferCapture.capture(), eq(false))).thenReturn(false) + .thenReturn(true) + .thenReturn(false); + when(offHeapVectorTransfer.flush(false)).thenReturn(true); + when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); + + BuildIndexParams buildIndexParams = BuildIndexParams.builder() + .indexPath("indexPath") + .knnEngine(KNNEngine.FAISS) + .vectorDataType(VectorDataType.FLOAT) + .parameters(Map.of("index", "param")) + .quantizationState(quantizationState) + .build(); + + // When + MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues); + + // Then + mockedJNIService.verify( + () -> JNIService.initIndex( + knnVectorValues.totalLiveDocs(), + knnVectorValues.dimension(), + Map.of("index", "param"), + KNNEngine.FAISS + ) + ); + + mockedJNIService.verify( + () -> JNIService.insertToIndex( + eq(new int[] { 0, 1 }), + 
vectorAddressCaptor.capture(), + eq(knnVectorValues.dimension()), + eq(Map.of("index", "param")), + eq(100L), + eq(KNNEngine.FAISS) + ) + ); + + // For the flush + mockedJNIService.verify( + () -> JNIService.insertToIndex( + eq(new int[] { 2 }), + vectorAddressCaptor.capture(), + eq(knnVectorValues.dimension()), + eq(Map.of("index", "param")), + eq(100L), + eq(KNNEngine.FAISS) + ) + ); + + mockedJNIService.verify( + () -> JNIService.writeIndex(eq("indexPath"), eq(100L), eq(KNNEngine.FAISS), eq(Map.of("index", "param"))) + ); + assertEquals(200L, vectorAddressCaptor.getValue().longValue()); + assertEquals(vectorAddressCaptor.getValue().longValue(), vectorAddressCaptor.getAllValues().get(0).longValue()); + verify(offHeapVectorTransfer, times(0)).reset(); + + for (Object vector : vectorTransferCapture.getAllValues()) { + // Assert that the vector is in byte[] format due to quantization + assertTrue(vector instanceof byte[]); + } + } + } } diff --git a/src/test/java/org/opensearch/knn/index/quantizationService/QuantizationServiceTests.java b/src/test/java/org/opensearch/knn/index/quantizationService/QuantizationServiceTests.java new file mode 100644 index 0000000000..886dbeabc8 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/quantizationService/QuantizationServiceTests.java @@ -0,0 +1,159 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.quantizationService; + +import org.opensearch.knn.KNNTestCase; +import org.junit.Before; + +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import 
org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import java.io.IOException; +import java.util.List; + +public class QuantizationServiceTests extends KNNTestCase { + private QuantizationService quantizationService; + private KNNVectorValues knnVectorValues; + + @Before + public void setUp() throws Exception { + super.setUp(); + quantizationService = QuantizationService.getInstance(); + + // Predefined float vectors for testing + List floatVectors = List.of( + new float[] { 1.0f, 2.0f, 3.0f }, + new float[] { 4.0f, 5.0f, 6.0f }, + new float[] { 7.0f, 8.0f, 9.0f } + ); + + // Use the predefined vectors to create KNNVectorValues + knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + new TestVectorValues.PreDefinedFloatVectorValues(floatVectors) + ); + } + + public void testTrain_oneBitQuantizer_success() throws IOException { + ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues); + + assertTrue(quantizationState instanceof OneBitScalarQuantizationState); + OneBitScalarQuantizationState oneBitState = (OneBitScalarQuantizationState) quantizationState; + + // Validate the mean thresholds obtained from the training + float[] thresholds = oneBitState.getMeanThresholds(); + assertNotNull("Thresholds should not be null", thresholds); + assertEquals("Thresholds array length should match the dimension", 3, thresholds.length); + + // Example expected thresholds based on the provided vectors + assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, thresholds, 0.1f); + } + + public void 
testTrain_twoBitQuantizer_success() throws IOException { + ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues); + + assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); + MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState; + + // Validate the thresholds obtained from the training + float[][] thresholds = multiBitState.getThresholds(); + assertNotNull("Thresholds should not be null", thresholds); + assertEquals("Number of bits should match the number of rows", 2, thresholds.length); + assertEquals("Thresholds array length should match the dimension", 3, thresholds[0].length); + + // // Example expected thresholds for two-bit quantization + float[][] expectedThresholds = { + { 3.1835034f, 4.1835036f, 5.1835036f }, // First bit level + { 4.816497f, 5.816497f, 6.816497f } // Second bit level + }; + for (int i = 0; i < thresholds.length; i++) { + assertArrayEquals(expectedThresholds[i], thresholds[i], 0.1f); + } + } + + public void testTrain_fourBitQuantizer_success() throws IOException { + ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); + QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues); + + assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); + MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState; + + // Validate the thresholds obtained from the training + float[][] thresholds = multiBitState.getThresholds(); + assertNotNull("Thresholds should not be null", thresholds); + assertEquals("Number of bits should match the number of rows", 4, thresholds.length); + assertEquals("Thresholds array length should match the dimension", 3, thresholds[0].length); + + // // Example expected thresholds 
for four-bit quantization + float[][] expectedThresholds = { + { 2.530306f, 3.530306f, 4.530306f }, // First bit level + { 3.510102f, 4.5101023f, 5.5101023f }, // Second bit level + { 4.489898f, 5.489898f, 6.489898f }, // Third bit level + { 5.469694f, 6.469694f, 7.469694f } // Fourth bit level + }; + for (int i = 0; i < thresholds.length; i++) { + assertArrayEquals(expectedThresholds[i], thresholds[i], 0.1f); + } + } + + public void testQuantize_oneBitQuantizer_success() throws IOException { + ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues); + + QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams); + + byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 1.0f, 2.0f, 3.0f }, quantizationOutput); + + assertNotNull("Quantized vector should not be null", quantizedVector); + + // Expected quantized vector values for one-bit quantization (packed bits) + byte[] expectedQuantizedVector = new byte[] { 0 }; // 00000000 (all bits are 0) + assertArrayEquals(expectedQuantizedVector, quantizedVector); + } + + public void testQuantize_twoBitQuantizer_success() throws IOException { + ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues); + QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(twoBitParams); + byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 4.0f, 5.0f, 6.0f }, quantizationOutput); + + assertNotNull("Quantized vector should not be null", quantizedVector); + + // Expected quantized vector values for two-bit quantization (packed bits) + byte[] expectedQuantizedVector = new byte[] { (byte) 0b11100000 }; + 
assertArrayEquals(expectedQuantizedVector, quantizedVector); + } + + public void testQuantize_fourBitQuantizer_success() throws IOException { + ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); + QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues); + QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(fourBitParams); + + byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 7.0f, 8.0f, 9.0f }, quantizationOutput); + + assertNotNull("Quantized vector should not be null", quantizedVector); + + // Expected quantized vector values for four-bit quantization (packed bits) + byte[] expectedQuantizedVector = new byte[] { (byte) 0xFF, (byte) 0xF0 }; + assertArrayEquals(expectedQuantizedVector, quantizedVector); + } + + public void testQuantize_whenInvalidInput_thenThrows() throws IOException { + ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues); + QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams); + assertThrows(IllegalArgumentException.class, () -> quantizationService.quantize(quantizationState, null, quantizationOutput)); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java index b95123e212..f6974aea21 100644 --- a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java @@ -5,7 +5,6 @@ package org.opensearch.knn.quantization.factory; -import org.junit.Before; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import 
org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; @@ -13,31 +12,22 @@ import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; import org.opensearch.knn.quantization.quantizer.Quantizer; -import java.lang.reflect.Field; -import java.util.concurrent.atomic.AtomicBoolean; - public class QuantizerFactoryTests extends KNNTestCase { - @Before - public void resetIsRegisteredFlag() throws NoSuchFieldException, IllegalAccessException { - Field isRegisteredField = QuantizerFactory.class.getDeclaredField("isRegistered"); - isRegisteredField.setAccessible(true); - AtomicBoolean isRegistered = (AtomicBoolean) isRegisteredField.get(null); - isRegistered.set(false); - } - public void test_Lazy_Registration() { - ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - ScalarQuantizationParams paramsTwoBit = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - ScalarQuantizationParams paramsFourBit = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - assertFalse(isRegisteredFieldAccessible()); - Quantizer oneBitQuantizer = QuantizerFactory.getQuantizer(params); - Quantizer quantizerTwoBit = QuantizerFactory.getQuantizer(paramsTwoBit); - Quantizer quantizerFourBit = QuantizerFactory.getQuantizer(paramsFourBit); - assertEquals(quantizerFourBit.getClass(), MultiBitScalarQuantizer.class); - assertEquals(quantizerTwoBit.getClass(), MultiBitScalarQuantizer.class); - assertEquals(oneBitQuantizer.getClass(), OneBitScalarQuantizer.class); - assertTrue(isRegisteredFieldAccessible()); + try { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + ScalarQuantizationParams paramsTwoBit = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + ScalarQuantizationParams paramsFourBit = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); + Quantizer oneBitQuantizer = QuantizerFactory.getQuantizer(params); + Quantizer 
quantizerTwoBit = QuantizerFactory.getQuantizer(paramsTwoBit); + Quantizer quantizerFourBit = QuantizerFactory.getQuantizer(paramsFourBit); + assertEquals(OneBitScalarQuantizer.class, oneBitQuantizer.getClass()); + assertEquals(MultiBitScalarQuantizer.class, quantizerTwoBit.getClass()); + assertEquals(MultiBitScalarQuantizer.class, quantizerFourBit.getClass()); + } catch (Exception e) { + assertTrue(e.getMessage().contains("already registered")); + } } public void testGetQuantizer_withNullParams() { @@ -48,16 +38,4 @@ public void testGetQuantizer_withNullParams() { assertEquals("Quantization parameters must not be null.", e.getMessage()); } } - - private boolean isRegisteredFieldAccessible() { - try { - Field isRegisteredField = QuantizerFactory.class.getDeclaredField("isRegistered"); - isRegisteredField.setAccessible(true); - AtomicBoolean isRegistered = (AtomicBoolean) isRegisteredField.get(null); - return isRegistered.get(); - } catch (NoSuchFieldException | IllegalAccessException e) { - fail("Failed to access isRegistered field."); - return false; - } - } } diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java index 62d31ab61d..7c974e5172 100644 --- a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java @@ -17,18 +17,22 @@ public class QuantizerRegistryTests extends KNNTestCase { @BeforeClass public static void setup() { - QuantizerRegistry.register( - ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT), - new OneBitScalarQuantizer() - ); - QuantizerRegistry.register( - ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.TWO_BIT), - new MultiBitScalarQuantizer(2) - ); - QuantizerRegistry.register( - 
ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.FOUR_BIT), - new MultiBitScalarQuantizer(4) - ); + try { + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT), + new OneBitScalarQuantizer() + ); + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.TWO_BIT), + new MultiBitScalarQuantizer(2) + ); + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.FOUR_BIT), + new MultiBitScalarQuantizer(4) + ); + } catch (Exception e) { + assertTrue(e.getMessage().contains("already registered")); + } } public void testRegisterAndGetQuantizer() { diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java index 35edf49e2c..298256127e 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.quantization.quantizationState; +import org.apache.lucene.util.RamUsageEstimator; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; @@ -65,4 +66,76 @@ public void testSerializationWithDifferentVersions() throws IOException { assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta); assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType()); } + + public void testOneBitScalarQuantizationStateRamBytesUsed() throws IOException { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + float[] mean = { 1.0f, 2.0f, 3.0f }; + + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, 
mean); + + // 1. Manual Calculation of RAM Usage + long manualEstimatedRamBytesUsed = 0L; + + // OneBitScalarQuantizationState object overhead for Object Header + manualEstimatedRamBytesUsed += alignSize(16L); + + // ScalarQuantizationParams object overhead Object Header + manualEstimatedRamBytesUsed += alignSize(16L); + + // Mean array overhead (array header + size of elements) + manualEstimatedRamBytesUsed += alignSize(16L + 4L * mean.length); + + // 3. RAM Usage from RamUsageEstimator + long expectedRamBytesUsed = RamUsageEstimator.shallowSizeOfInstance(OneBitScalarQuantizationState.class) + RamUsageEstimator + .shallowSizeOf(params) + RamUsageEstimator.sizeOf(mean); + + long actualRamBytesUsed = state.ramBytesUsed(); + + // Allow a difference between manual estimation, serialization size, and actual RAM usage + assertTrue( + "The difference between manual and actual RAM usage exceeds 8 bytes", + Math.abs(manualEstimatedRamBytesUsed - actualRamBytesUsed) <= 8 + ); + + assertEquals(expectedRamBytesUsed, actualRamBytesUsed); + } + + public void testMultiBitScalarQuantizationStateRamBytesUsedManualCalculation() throws IOException { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + float[][] thresholds = { { 0.5f, 1.5f, 2.5f }, { 1.0f, 2.0f, 3.0f } }; + + MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + + // Manually estimate RAM usage with alignment + long manualEstimatedRamBytesUsed = 0L; + + // Estimate for MultiBitScalarQuantizationState object + manualEstimatedRamBytesUsed += alignSize(16L); // Example overhead for object + + // Estimate for ScalarQuantizationParams object + manualEstimatedRamBytesUsed += alignSize(16L); // Overhead for params object (including fields) + + // Estimate for thresholds array + manualEstimatedRamBytesUsed += alignSize(16L + 4L * thresholds.length); // Overhead for array + references to sub-arrays + + for (float[] row : thresholds) { 
+ manualEstimatedRamBytesUsed += alignSize(16L + 4L * row.length); // Overhead for each sub-array + size of each float + } + + long ramEstimatorRamBytesUsed = RamUsageEstimator.shallowSizeOfInstance(MultiBitScalarQuantizationState.class) + RamUsageEstimator + .shallowSizeOf(params) + RamUsageEstimator.shallowSizeOf(thresholds); + + for (float[] row : thresholds) { + ramEstimatorRamBytesUsed += RamUsageEstimator.sizeOf(row); + } + + long difference = Math.abs(manualEstimatedRamBytesUsed - ramEstimatorRamBytesUsed); + assertTrue("The difference between manual and actual RAM usage exceeds 8 bytes", difference <= 8); + assertEquals(ramEstimatorRamBytesUsed, state.ramBytesUsed()); + } + + private long alignSize(long size) { + return (size + 7) & ~7; // Align to 8 bytes boundary + } + }