Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Integration With Quantization Config
Browse files Browse the repository at this point in the history
Signed-off-by: VIKASH TIWARI <viktari@amazon.com>
Vikasht34 committed Aug 22, 2024
1 parent c310f72 commit e9cae9b
Showing 10 changed files with 169 additions and 95 deletions.
Original file line number Diff line number Diff line change
@@ -11,11 +11,8 @@
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.security.AccessController;
@@ -57,34 +54,14 @@ public static DefaultIndexBuildStrategy getInstance() {
public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException {
// Needed to make sure we don't get 0 dimensions while initializing index
iterateVectorValuesOnce(knnVectorValues);
QuantizationService quantizationHandler = QuantizationService.getInstance();
QuantizationState quantizationState = indexInfo.getQuantizationState();
QuantizationOutput quantizationOutput = null;
IndexBuildSetup indexBuildSetup = IndexBuildHelper.prepareIndexBuild(knnVectorValues, indexInfo);

int bytesPerVector;
int dimensions;

// Handle quantization state if present
if (quantizationState != null) {
bytesPerVector = quantizationState.getBytesPerVector();
dimensions = quantizationState.getDimensions();
quantizationOutput = quantizationHandler.createQuantizationOutput(quantizationState.getQuantizationParams());
} else {
bytesPerVector = knnVectorValues.bytesPerVector();
dimensions = knnVectorValues.dimension();
}

int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector);
int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector());
try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) {
final List<Integer> transferredDocIds = new ArrayList<>((int) knnVectorValues.totalLiveDocs());

while (knnVectorValues.docId() != NO_MORE_DOCS) {
if (quantizationState != null && quantizationOutput != null) {
quantizationHandler.quantize(quantizationState, knnVectorValues.getVector(), quantizationOutput);
vectorTransfer.transfer(quantizationOutput.getQuantizedVector(), true);
} else {
vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), true);
}
IndexBuildHelper.processAndTransferVector(knnVectorValues, indexBuildSetup, vectorTransfer, true);
// append is true here so off heap memory buffer isn't overwritten
transferredDocIds.add(knnVectorValues.docId());
knnVectorValues.nextDoc();
@@ -100,7 +77,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.createIndexFromTemplate(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexInfo.getIndexPath(),
(byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER),
params,
@@ -113,7 +90,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.createIndex(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexInfo.getIndexPath(),
params,
indexInfo.getKnnEngine()
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.nativeindex;

import lombok.experimental.UtilityClass;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;

/**
 * Package-private helpers shared by the native index build strategies: one method
 * derives the per-build {@link IndexBuildSetup}, the other pushes the current vector
 * (quantized or raw) into an {@link OffHeapVectorTransfer}.
 */
@UtilityClass
class IndexBuildHelper {

    /**
     * Transfers the vector that {@code knnVectorValues} is currently positioned on.
     * <p>
     * When the setup carries both a quantization state and a quantization output, the
     * vector is quantized into that output and the quantized form is transferred.
     * Otherwise the vector obtained from {@code conditionalCloneVector()} is transferred
     * as-is.
     *
     * @param knnVectorValues source of the current vector.
     * @param indexBuildSetup setup produced by {@link #prepareIndexBuild}; supplies the
     *                        optional quantization state and output.
     * @param vectorTransfer  destination the vector is handed to.
     * @param append          forwarded unchanged to {@code vectorTransfer.transfer(...)}.
     * @return the value returned by {@code vectorTransfer.transfer(...)}.
     * @throws IOException if reading or transferring the vector fails.
     */
    static boolean processAndTransferVector(
        KNNVectorValues<?> knnVectorValues,
        IndexBuildSetup indexBuildSetup,
        OffHeapVectorTransfer vectorTransfer,
        boolean append
    ) throws IOException {
        QuantizationService quantizer = QuantizationService.getInstance();
        QuantizationState state = indexBuildSetup.getQuantizationState();
        QuantizationOutput output = indexBuildSetup.getQuantizationOutput();

        // Without both pieces of quantization context the vector goes through untouched.
        if (state == null || output == null) {
            return vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), append);
        }

        // quantize(...) fills `output`; the quantized vector is then read back from it.
        quantizer.quantize(state, knnVectorValues.getVector(), output);
        return vectorTransfer.transfer(output.getQuantizedVector(), append);
    }

    /**
     * Builds the {@link IndexBuildSetup} for one index build.
     * <p>
     * If {@code indexInfo} carries a quantization state, bytes-per-vector and dimensions
     * are taken from that state and a quantization output is created from its parameters.
     * If not, both numbers come from {@code knnVectorValues} and the setup's quantization
     * output and state are {@code null}.
     *
     * @param knnVectorValues the KNN vector values being indexed.
     * @param indexInfo       the index build parameters.
     * @return a new {@link IndexBuildSetup} describing the build.
     */
    static IndexBuildSetup prepareIndexBuild(KNNVectorValues<?> knnVectorValues, BuildIndexParams indexInfo) {
        QuantizationState state = indexInfo.getQuantizationState();
        QuantizationService quantizer = QuantizationService.getInstance();

        if (state == null) {
            // No quantization: sizes come straight from the vector values.
            return new IndexBuildSetup(knnVectorValues.bytesPerVector(), knnVectorValues.dimension(), null, state);
        }

        // Quantization: sizes and the reusable output holder come from the state.
        return new IndexBuildSetup(
            state.getBytesPerVector(),
            state.getDimensions(),
            quantizer.createQuantizationOutput(state.getQuantizationParams()),
            state
        );
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.nativeindex;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

/**
* IndexBuildSetup is an immutable holder for the values a native index build needs:
* the size of each vector in bytes, the vector dimensions, and the optional
* quantization output and state.
* <p>
* Lombok generates the accessors ({@code @Getter}) and a single constructor
* ({@code @AllArgsConstructor}) whose parameters follow the field declaration order
* below: {@code (bytesPerVector, dimensions, quantizationOutput, quantizationState)}.
* Reordering the fields therefore changes the constructor signature.
* <p>
* Instances are created by {@code IndexBuildHelper.prepareIndexBuild}; when the build
* parameters carry no quantization state, both quantization fields are {@code null}.
*/
@Getter
@AllArgsConstructor
final class IndexBuildSetup {
/**
* The number of bytes per vector. Taken from the quantization state when one is
* present, otherwise from the vector values being indexed.
*/
private final int bytesPerVector;

/**
* The number of dimensions in the vector. Taken from the quantization state when one
* is present, otherwise from the vector values being indexed.
*/
private final int dimensions;

/**
* The quantization output that will hold the quantized vector.
* {@code null} when no quantization state is present.
*/
private final QuantizationOutput quantizationOutput;

/**
* The state of quantization, which may include parameters and trained models.
* {@code null} when the index is built without quantization.
*/
private final QuantizationState quantizationState;
}
Original file line number Diff line number Diff line change
@@ -11,11 +11,8 @@
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.security.AccessController;
@@ -60,48 +57,26 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
iterateVectorValuesOnce(knnVectorValues);
KNNEngine engine = indexInfo.getKnnEngine();
Map<String, Object> indexParameters = indexInfo.getParameters();
QuantizationService quantizationHandler = QuantizationService.getInstance();
QuantizationState quantizationState = indexInfo.getQuantizationState();
QuantizationOutput quantizationOutput = null;

int bytesPerVector;
int dimensions;

// Handle quantization state if present
if (quantizationState != null) {
bytesPerVector = quantizationState.getBytesPerVector();
dimensions = quantizationState.getDimensions();
quantizationOutput = quantizationHandler.createQuantizationOutput(quantizationState.getQuantizationParams());
} else {
bytesPerVector = knnVectorValues.bytesPerVector();
dimensions = knnVectorValues.dimension();
}
IndexBuildSetup indexBuildSetup = IndexBuildHelper.prepareIndexBuild(knnVectorValues, indexInfo);

// Initialize the index
long indexMemoryAddress = AccessController.doPrivileged(
(PrivilegedAction<Long>) () -> JNIService.initIndex(
knnVectorValues.totalLiveDocs(),
knnVectorValues.dimension(),
indexBuildSetup.getDimensions(),
indexParameters,
engine
)
);

int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / bytesPerVector);
int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector());
try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) {

final List<Integer> transferredDocIds = new ArrayList<>(transferLimit);

while (knnVectorValues.docId() != NO_MORE_DOCS) {
// append is false to be able to reuse the memory location
boolean transferred;
if (quantizationState != null && quantizationOutput != null) {
quantizationHandler.quantize(quantizationState, knnVectorValues.getVector(), quantizationOutput);
transferred = vectorTransfer.transfer(quantizationOutput.getQuantizedVector(), false);
} else {
transferred = vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), false);
}
// append is false to be able to reuse the memory location
boolean transferred = IndexBuildHelper.processAndTransferVector(knnVectorValues, indexBuildSetup, vectorTransfer, false);
transferredDocIds.add(knnVectorValues.docId());
if (transferred) {
// Insert vectors
@@ -110,7 +85,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.insertToIndex(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexParameters,
indexMemoryAddress,
engine
@@ -130,7 +105,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
JNIService.insertToIndex(
intListToArray(transferredDocIds),
vectorAddress,
dimensions,
indexBuildSetup.getDimensions(),
indexParameters,
indexMemoryAddress,
engine
Original file line number Diff line number Diff line change
@@ -97,9 +97,7 @@ static KNNLibraryIndexingContext adjustPrefix(
// We need to update the prefix used to create the faiss index if we are using the quantization
// framework
if (encoderContext != null && Objects.equals(encoderContext.getName(), QFrameBitEncoder.NAME)) {
// TODO: Uncomment to use Quantization framework
// leaving commented now just so it wont fail creating faiss indices.
// prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
}

if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) {
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
* KNNVectorQuantizationTrainingRequest is a concrete implementation of the abstract TrainingRequest class.
* It provides a mechanism to retrieve float vectors from the KNNVectorValues by document ID.
*/
class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {
final class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {

private final KNNVectorValues<T> knnVectorValues;
private int lastIndex;
@@ -39,15 +39,13 @@ class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {
@Override
public T getVectorByDocId(int docId) {
try {
int index = lastIndex;
while (index <= docId) {
while (lastIndex <= docId) {
knnVectorValues.nextDoc();
index++;
lastIndex++;
}
if (knnVectorValues.docId() == NO_MORE_DOCS) {
return null;
}
lastIndex = index;
// Return the vector and the updated index
return knnVectorValues.getVector();
} catch (Exception e) {
Original file line number Diff line number Diff line change
@@ -9,8 +9,8 @@
import lombok.NoArgsConstructor;
import org.apache.lucene.index.FieldInfo;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.factory.QuantizerFactory;
import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
@@ -20,6 +20,8 @@
import org.opensearch.knn.quantization.quantizer.Quantizer;
import java.io.IOException;

import static org.opensearch.knn.common.FieldInfoExtractor.extractQuantizationConfig;

/**
* A singleton class responsible for handling the quantization process, including training a quantizer
* and applying quantization to vectors. This class is designed to be thread-safe.
@@ -28,7 +30,7 @@
* @param <R> The type of the quantized output vectors.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class QuantizationService<T, R> {
public final class QuantizationService<T, R> {

/**
* The singleton instance of the {@link QuantizationService} class.
@@ -85,9 +87,9 @@ public R quantize(final QuantizationState quantizationState, final T vector, fin
* Retrieves quantization parameters from the FieldInfo.
*/
public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) {
// TODO: Replace this with actual logic to extract quantization parameters from FieldInfo
if (fieldInfo.getAttribute("QuantizationConfig") != null) {
return new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo);
if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) {
return new ScalarQuantizationParams(quantizationConfig.getQuantizationType());
}
return null;
}
@@ -101,8 +103,11 @@ public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) {
* @return The {@link VectorDataType} to be used during the vector transfer process
*/
public VectorDataType getVectorDataTypeForTransfer(final FieldInfo fieldInfo) {
// TODO: Replace this with actual logic to extract quantization parameters from FieldInfo
return VectorDataType.BINARY;
QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo);
if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) {
return VectorDataType.BINARY;
}
return null;
}

/**
Loading

0 comments on commit e9cae9b

Please sign in to comment.