Skip to content

Commit

Permalink
Integration With Qunatization Config
Browse files Browse the repository at this point in the history
Signed-off-by: VIKASH TIWARI <[email protected]>
  • Loading branch information
Vikasht34 committed Aug 22, 2024
1 parent c310f72 commit e3a6d17
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import lombok.NoArgsConstructor;
import org.apache.lucene.index.FieldInfo;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.factory.QuantizerFactory;
import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
Expand All @@ -20,6 +20,8 @@
import org.opensearch.knn.quantization.quantizer.Quantizer;
import java.io.IOException;

import static org.opensearch.knn.common.FieldInfoExtractor.extractQuantizationConfig;

/**
* A singleton class responsible for handling the quantization process, including training a quantizer
* and applying quantization to vectors. This class is designed to be thread-safe.
Expand Down Expand Up @@ -85,9 +87,9 @@ public R quantize(final QuantizationState quantizationState, final T vector, fin
* Retrieves quantization parameters from the FieldInfo.
*/
public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) {
// TODO: Replace this with actual logic to extract quantization parameters from FieldInfo
if (fieldInfo.getAttribute("QuantizationConfig") != null) {
return new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo);
if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) {
return new ScalarQuantizationParams(quantizationConfig.getQuantizationType());
}
return null;
}
Expand All @@ -101,8 +103,11 @@ public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) {
* @return The {@link VectorDataType} to be used during the vector transfer process
*/
public VectorDataType getVectorDataTypeForTransfer(final FieldInfo fieldInfo) {
// TODO: Replace this with actual logic to extract quantization parameters from FieldInfo
return VectorDataType.BINARY;
QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo);
if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) {
return VectorDataType.BINARY;
}
return null;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,11 @@
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.engine.qframe.QuantizationConfigParser;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -113,7 +116,8 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc
float[] floatVectorForBinaryQuantization_2 = { 1.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f };
FieldType fieldTypeForBinaryQuantization = createVectorField(8, VectorEncoding.FLOAT32, VectorDataType.FLOAT);
fieldTypeForBinaryQuantization.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"BHNSW32\", \"spaceType\": \"l2\"}");
fieldTypeForBinaryQuantization.putAttribute("QuantizationConfig", "{ \"type\": \"Binary\" }");
QuantizationConfig quantizationConfig = QuantizationConfig.builder().quantizationType(ScalarQuantizationType.ONE_BIT).build();
fieldTypeForBinaryQuantization.putAttribute(KNNConstants.QFRAMEWORK_CONFIG, QuantizationConfigParser.toCsv(quantizationConfig));
fieldTypeForBinaryQuantization.freeze();

addFieldToIndex(
Expand Down Expand Up @@ -187,7 +191,8 @@ public void testNativeEngineVectorFormat_whenBinaryQuantizationApplied_thenSucce
float[] floatVectorForBinaryQuantization = { 1.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f };
FieldType fieldTypeForBinaryQuantization = createVectorField(8, VectorEncoding.FLOAT32, VectorDataType.FLOAT);
fieldTypeForBinaryQuantization.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"BHNSW32\", \"spaceType\": \"l2\"}");
fieldTypeForBinaryQuantization.putAttribute("QuantizationConfig", "{ \"type\": \"Binary\" }");
QuantizationConfig quantizationConfig = QuantizationConfig.builder().quantizationType(ScalarQuantizationType.ONE_BIT).build();
fieldTypeForBinaryQuantization.putAttribute(KNNConstants.QFRAMEWORK_CONFIG, QuantizationConfigParser.toCsv(quantizationConfig));

addFieldToIndex(
new KnnFloatVectorField(FLOAT_VECTOR_FIELD_BINARY, floatVectorForBinaryQuantization, fieldTypeForBinaryQuantization),
Expand Down

0 comments on commit e3a6d17

Please sign in to comment.