Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
QuantizationFramework Changes
Browse files Browse the repository at this point in the history
Vikasht34 committed Jul 8, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 9efcb16 commit 68f29f8
Showing 32 changed files with 1,165 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -21,6 +21,11 @@
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesIterator;
import org.opensearch.knn.jni.JNICommons;
import org.opensearch.knn.quantization.QuantizationManager;
import org.opensearch.knn.quantization.enums.SQTypes;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
import org.opensearch.knn.quantization.quantizer.Quantizer;

import java.io.IOException;
import java.util.ArrayList;
@@ -56,24 +61,28 @@ private static void createNativeIndex(
}

private static KNNCodecUtil.Pair streamFloatVectors(final KNNVectorValues<float[]> kNNVectorValues) throws IOException {
List<float[]> vectorList = new ArrayList<>();
List<byte[]> vectorList = new ArrayList<>();
List<Integer> docIdList = new ArrayList<>();
long vectorAddress = 0;
int dimension = 0;
long totalLiveDocs = kNNVectorValues.totalLiveDocs();
long vectorsStreamingMemoryLimit = KNNSettings.getVectorStreamingMemoryLimit().getBytes();
long vectorsPerTransfer = Integer.MIN_VALUE;

QuantizationParams params = getQuantizationParams(); // Implement this method to get appropriate params
Quantizer<float[], byte[]> quantizer = (Quantizer<float[], byte[]>) QuantizationManager.getInstance().getQuantizer(params);

KNNVectorValuesIterator iterator = kNNVectorValues.getVectorValuesIterator();

for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) {
float[] temp = kNNVectorValues.getVector();
// This temp object and copy of temp object is required because when we map floats we read to a memory
// location in heap always for floatVectorValues. Ref: OffHeapFloatVectorValues.vectorValue.
float[] vector = Arrays.copyOf(temp, temp.length);
byte[] quantizedVector = quantizer.quantize(vector).getQuantizedVector();
dimension = vector.length;
if (vectorsPerTransfer == Integer.MIN_VALUE) {
vectorsPerTransfer = (dimension * Float.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit;
vectorsPerTransfer = (dimension * Byte.BYTES * totalLiveDocs) / vectorsStreamingMemoryLimit;
// This condition comes if vectorsStreamingMemoryLimit is higher than total number floats to transfer
// Doing this will reduce 1 extra trip to JNI layer.
if (vectorsPerTransfer == 0) {
@@ -82,19 +91,19 @@ private static KNNCodecUtil.Pair streamFloatVectors(final KNNVectorValues<float[
}

if (vectorList.size() == vectorsPerTransfer) {
vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension);
vectorAddress = JNICommons.storeByteVectorData(vectorAddress, vectorList.toArray(new byte[][] {}), totalLiveDocs * dimension);
// We should probably come up with a better way to reuse the vectorList memory which we have
// created. Problem here is doing like this can lead to a lot of list memory which is of no use and
// will be garbage collected later on, but it creates pressure on JVM. We should revisit this.
vectorList = new ArrayList<>();
}

vectorList.add(vector);
vectorList.add(quantizedVector);
docIdList.add(doc);
}

if (vectorList.isEmpty() == false) {
vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension);
vectorAddress = JNICommons.storeByteVectorData(vectorAddress, vectorList.toArray(new byte[][] {}), totalLiveDocs * dimension);
}
// SerializationMode.COLLECTION_OF_FLOATS is not getting used. I just added it to ensure code successfully
// works.
@@ -105,4 +114,9 @@ private static KNNCodecUtil.Pair streamFloatVectors(final KNNVectorValues<float[
SerializationMode.COLLECTION_OF_FLOATS
);
}

private static QuantizationParams getQuantizationParams() {
// Implement this method to return appropriate quantization parameters based on your use case
return new SQParams(SQTypes.ONE_BIT); // Example, modify as needed
}
}
20 changes: 17 additions & 3 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
@@ -43,6 +43,11 @@
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.knn.quantization.QuantizationManager;
import org.opensearch.knn.quantization.enums.SQTypes;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
import org.opensearch.knn.quantization.quantizer.Quantizer;

import java.io.IOException;
import java.nio.file.Path;
@@ -154,6 +159,11 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
return convertSearchResponseToScorer(docIdsToScoreMap);
}

private QuantizationParams getQuantizationParams() {
// Implement this method to return appropriate quantization parameters based on your use case
return new SQParams(SQTypes.ONE_BIT); // Example, modify as needed
}

private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException {
if (this.filterWeight == null) {
return new FixedBitSet(0);
@@ -211,6 +221,9 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final B
SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader());
String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString();

QuantizationParams params = getQuantizationParams(); // Implement this method to get appropriate params
Quantizer<float[], byte[]> quantizer = (Quantizer<float[], byte[]>) QuantizationManager.getInstance().getQuantizer(params);

FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());

if (fieldInfo == null) {
@@ -272,7 +285,7 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final B
spaceType,
knnEngine,
knnQuery.getIndexName(),
FieldInfoExtractor.getIndexDescription(fieldInfo)
"B" + FieldInfoExtractor.getIndexDescription(fieldInfo)
),
knnQuery.getIndexName(),
modelId
@@ -295,10 +308,11 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final B
throw new RuntimeException("Index has already been closed");
}
int[] parentIds = getParentIdsArray(context);
byte[] quantizedVector = quantizer.quantize(knnQuery.getQueryVector()).getQuantizedVector();
if (knnQuery.getK() > 0) {
results = JNIService.queryIndex(
results = JNIService.queryBinaryIndex(
indexAllocation.getMemoryAddress(),
knnQuery.getQueryVector(),
quantizedVector,
knnQuery.getK(),
knnEngine,
filterIds,
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/knn/jni/JNIService.java
Original file line number Diff line number Diff line change
@@ -52,6 +52,11 @@ public static void createIndex(
}

if (KNNEngine.FAISS == knnEngine) {
if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) {
String indexDesc = (String) parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER);
parameters.put(KNNConstants.INDEX_DESCRIPTION_PARAMETER ,"B" + indexDesc);

}
if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null
&& parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_PREFIX)) {
FaissService.createBinaryIndex(ids, vectorsAddress, dim, indexPath, parameters);
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization;

import org.opensearch.knn.quantization.factory.QuantizerFactory;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.SamplingTrainingRequest;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
import org.opensearch.knn.quantization.quantizer.Quantizer;
import org.opensearch.knn.quantization.sampler.Sampler;
import org.opensearch.knn.quantization.sampler.SamplingFactory;

public class QuantizationManager {
private static QuantizationManager instance;

private QuantizationManager() {}

public static QuantizationManager getInstance() {
if (instance == null) {
instance = new QuantizationManager();
}
return instance;
}
public <T, R> QuantizationState train(TrainingRequest<T> trainingRequest) {
Quantizer<T, R> quantizer = (Quantizer<T, R>) getQuantizer(trainingRequest.getParams());
int sampleSize = quantizer.getSamplingSize();
Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR);
TrainingRequest<T> sampledRequest = new SamplingTrainingRequest<>(trainingRequest, sampler, sampleSize);
return quantizer.train(sampledRequest);
}
public Quantizer<?, ?> getQuantizer(QuantizationParams params) {
return QuantizerFactory.getQuantizer(params);
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.enums;

public enum QuantizationType {
SPACE_QUANTIZATION,
VALUE_QUANTIZATION,
}
21 changes: 21 additions & 0 deletions src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.enums;

public enum SQTypes {
FP16,
INT8,
INT6,
INT4,
ONE_BIT,
TWO_BIT
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.enums;

public enum ValueQuantizationType {
SQ
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.factory;

import org.opensearch.knn.quantization.enums.SQTypes;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer;
import org.opensearch.knn.quantization.quantizer.Quantizer;

public class QuantizerFactory {
static {
// Register all quantizers here
QuantizerRegistry.register(SQParams.class, SQTypes.ONE_BIT.name(), OneBitScalarQuantizer::new);
}

public static Quantizer<?, ?> getQuantizer(QuantizationParams params) {
if (params instanceof SQParams) {
SQParams sqParams = (SQParams) params;
return QuantizerRegistry.getQuantizer(params, sqParams.getSqType().name());
}
// Add more cases for other quantization parameters here
throw new IllegalArgumentException("Unsupported quantization parameters: " + params.getClass().getName());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.factory;

import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.quantizer.Quantizer;

import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

public class QuantizerRegistry {
private static final Map<Class<? extends QuantizationParams>, Map<String, Supplier<? extends Quantizer<?, ?>>>> registry = new HashMap<>();

public static <T extends QuantizationParams> void register(Class<T> paramClass, String typeIdentifier, Supplier<? extends Quantizer<?, ?>> quantizerSupplier) {
registry.computeIfAbsent(paramClass, k -> new HashMap<>()).put(typeIdentifier, quantizerSupplier);
}

public static Quantizer<?, ?> getQuantizer(QuantizationParams params, String typeIdentifier) {
Map<String, Supplier<? extends Quantizer<?, ?>>> typeMap = registry.get(params.getClass());
if (typeMap == null) {
throw new IllegalArgumentException("No quantizer registered for parameters: " + params.getClass().getName());
}
Supplier<? extends Quantizer<?, ?>> supplier = typeMap.get(typeIdentifier);
if (supplier == null) {
throw new IllegalArgumentException("No quantizer registered for type identifier: " + typeIdentifier);
}
return supplier.get();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.models.quantizationOutput;

public class OneBitScalarQuantizationOutput extends QuantizationOutput<byte[]> {

public OneBitScalarQuantizationOutput(byte[] quantizedVector) {
super(quantizedVector);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.models.quantizationOutput;

public abstract class QuantizationOutput<T> {
private final T quantizedVector;

public QuantizationOutput(T quantizedVector) {
this.quantizedVector = quantizedVector;
}

public T getQuantizedVector() {
return quantizedVector;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.models.quantizationParams;

import java.io.Serializable;
import org.opensearch.knn.quantization.enums.QuantizationType;

public abstract class QuantizationParams implements Serializable {
private QuantizationType quantizationType;

public QuantizationParams(QuantizationType quantizationType) {
this.quantizationType = quantizationType;
}

public QuantizationType getQuantizationType() {
return quantizationType;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.models.quantizationParams;

import org.opensearch.knn.quantization.enums.QuantizationType;
import org.opensearch.knn.quantization.enums.SQTypes;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;

public class SQParams extends QuantizationParams {
private SQTypes sqType;

public SQParams(SQTypes sqType) {
super(QuantizationType.VALUE_QUANTIZATION);
this.sqType = sqType;
}
public SQTypes getSqType() {
return sqType;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.models.quantizationState;


import org.opensearch.knn.quantization.models.quantizationParams.SQParams;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.ByteArrayInputStream;
import java.io.ObjectInputStream;

public class OneBitScalarQuantizationState extends QuantizationState {
private float[] mean;

public OneBitScalarQuantizationState(SQParams quantizationParams, float[] floatArray) {
super(quantizationParams);
this.mean = floatArray;
}

public float[] getMean() {
return mean;
}

@Override
public byte[] toByteArray() throws IOException {
byte[] parentBytes = super.toByteArray();
ByteArrayOutputStream bos = new ByteArrayOutputStream();
ObjectOutputStream out = new ObjectOutputStream(bos);
out.write(parentBytes);
out.writeObject(mean);
out.flush();
return bos.toByteArray();
}

public static OneBitScalarQuantizationState fromByteArray(byte[] bytes) throws IOException, ClassNotFoundException {
ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
ObjectInputStream in = new ObjectInputStream(bis);
QuantizationState parentState = (QuantizationState) in.readObject();
float[] floatArray = (float[]) in.readObject();
return new OneBitScalarQuantizationState((SQParams) parentState.getQuantizationParams(), floatArray);
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.models.quantizationState;

import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.ByteArrayInputStream;
import java.io.ObjectInputStream;

public abstract class QuantizationState implements Serializable {
private QuantizationParams quantizationParams;

public QuantizationState(QuantizationParams quantizationParams) {
this.quantizationParams = quantizationParams;
}

public QuantizationParams getQuantizationParams() {
return quantizationParams;
}

public byte[] toByteArray() throws IOException {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
ObjectOutputStream out = new ObjectOutputStream(bos);
out.writeObject(this);
out.flush();
return bos.toByteArray();
}

public static QuantizationState fromByteArray(byte[] bytes) throws IOException, ClassNotFoundException {
ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
ObjectInputStream in = new ObjectInputStream(bis);
return (QuantizationState) in.readObject();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.models.requests;

import org.opensearch.knn.quantization.sampler.Sampler;

import java.util.List;

public class SamplingTrainingRequest<T> extends TrainingRequest<T> {
private TrainingRequest<T> originalRequest;
private int[] sampledIndices;

public SamplingTrainingRequest(TrainingRequest<T> originalRequest, Sampler sampler, int sampleSize) {
super(originalRequest.getParams(), originalRequest.getTotalNumberOfVectors());
this.originalRequest = originalRequest;
this.sampledIndices = sampler.sample(originalRequest.getTotalNumberOfVectors(), sampleSize);
}

@Override
public T getVector() {
return originalRequest.getVector();
}

@Override
public T getVectorByDocId(int docId) {
return originalRequest.getVectorByDocId(docId);
}

public int[] getSampledIndices() {
return sampledIndices;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.models.requests;

import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;

public abstract class TrainingRequest<T> {
private QuantizationParams params;
private int totalNumberOfVectors;

public TrainingRequest(QuantizationParams params, int totalNumberOfVectors) {
this.params = params;
this.totalNumberOfVectors = totalNumberOfVectors;
}

public QuantizationParams getParams() {
return params;
}

public int getTotalNumberOfVectors() {
return totalNumberOfVectors;
}

public abstract T getVector();

public abstract T getVectorByDocId(int docId);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.quantizer;

import org.opensearch.knn.quantization.models.quantizationOutput.OneBitScalarQuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.SamplingTrainingRequest;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;

public class OneBitScalarQuantizer implements Quantizer<float[], byte[]> {
private static final int SAMPLING_SIZE = 25000;

@Override
public int getSamplingSize() {
return SAMPLING_SIZE;
}

@Override
public QuantizationState train(TrainingRequest<float[]> trainingRequest) {
if (!(trainingRequest instanceof SamplingTrainingRequest)) {
throw new IllegalArgumentException("Training request must be of type SamplingTrainingRequest.");
}

SamplingTrainingRequest<float[]> samplingRequest = (SamplingTrainingRequest<float[]>) trainingRequest;
int[] sampledIndices = samplingRequest.getSampledIndices();

if (sampledIndices == null || sampledIndices.length == 0) {
throw new IllegalArgumentException("Sampled indices must not be null or empty.");
}

int totalSamples = sampledIndices.length;
float[] sum = null;

// Calculate the sum for each dimension based on sampled indices
for (int i = 0; i < totalSamples; i++) {
float[] vector = samplingRequest.getVectorByDocId(sampledIndices[i]);
if (vector == null) {
throw new IllegalArgumentException("Vector at sampled index " + sampledIndices[i] + " is null.");
}
if (sum == null) {
sum = new float[vector.length];
} else if (sum.length != vector.length) {
throw new IllegalArgumentException("All vectors must have the same dimension.");
}
for (int j = 0; j < vector.length; j++) {
sum[j] += vector[j];
}
}
if (sum == null) {
throw new IllegalStateException("Sum array should not be null after processing vectors.");
}
// Calculate the mean for each dimension
float[] mean = new float[sum.length];
for (int j = 0; j < sum.length; j++) {
mean[j] = sum[j] / totalSamples;
}
SQParams params = (SQParams) trainingRequest.getParams();
if (params == null) {
throw new IllegalArgumentException("Quantization parameters must not be null.");
}
return new OneBitScalarQuantizationState(params, mean);
}

@Override
public QuantizationOutput<byte[]> quantize(float[] vector, QuantizationState state) {
if (vector == null) {
throw new IllegalArgumentException("Vector to quantize must not be null.");
}
if (!(state instanceof OneBitScalarQuantizationState)) {
throw new IllegalArgumentException("Quantization state must be of type OneBitScalarQuantizationState.");
}
OneBitScalarQuantizationState binaryState = (OneBitScalarQuantizationState) state;
float[] thresholds = binaryState.getMean();
if (thresholds == null || thresholds.length != vector.length) {
throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
}
byte[] quantizedVector = new byte[vector.length];
for (int i = 0; i < vector.length; i++) {
quantizedVector[i] = (byte) (vector[i] > thresholds[i] ? 1 : 0);
}
return new OneBitScalarQuantizationOutput(packBitsFromBitArray(quantizedVector));
}

@Override
public QuantizationOutput<byte[]> quantize(float[] vector) {
if (vector == null) {
throw new IllegalArgumentException("Vector to quantize must not be null.");
}
byte[] quantizedVector = new byte[vector.length];
for (int i = 0; i < vector.length; i++) {
quantizedVector[i] = (byte) (vector[i] > 0 ? 1 : 0);
}
return new OneBitScalarQuantizationOutput(packBitsFromBitArray(quantizedVector));
}

private byte[] packBitsFromBitArray(byte[] bitArray) {
int bitLength = bitArray.length;
int byteLength = (bitLength + 7) / 8;
byte[] packedArray = new byte[byteLength];

for (int i = 0; i < bitLength; i++) {
if (bitArray[i] != 0 && bitArray[i] != 1) {
throw new IllegalArgumentException("Array elements must be 0 or 1");
}
int byteIndex = i / 8;
int bitIndex = 7 - (i % 8);
packedArray[byteIndex] |= (bitArray[i] << bitIndex);
}

return packedArray;
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.quantizer;

import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
public interface Quantizer<T, R> {
int getSamplingSize();

default QuantizationState train(TrainingRequest<T> trainingRequest) {
throw new UnsupportedOperationException("Train method is not supported by this quantizer.");
}

default QuantizationOutput<R> quantize(T vector, QuantizationState state) {
throw new UnsupportedOperationException("Quantize method with state is not supported by this quantizer.");
}

default QuantizationOutput<R> quantize(T vector) {
throw new UnsupportedOperationException("Quantize method without state is not supported by this quantizer.");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.sampler;

import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class ReservoirSampler implements Sampler {
private final Random random = new Random();

@Override
public int[] sample(int totalNumberOfVectors, int sampleSize) {
if (totalNumberOfVectors <= sampleSize) {
return IntStream.range(0, totalNumberOfVectors).toArray();
}
return reservoirSampleIndices(totalNumberOfVectors, sampleSize);
}
private int[] reservoirSampleIndices(int numVectors, int sampleSize) {
int[] indices = IntStream.range(0, sampleSize).toArray();
for (int i = sampleSize; i < numVectors; i++) {
int j = random.nextInt(i + 1);
if (j < sampleSize) {
indices[j] = i;
}
}
Arrays.sort(indices);
return indices;
}
}
18 changes: 18 additions & 0 deletions src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.sampler;

import java.util.List;

public interface Sampler {
int[] sample(int totalNumberOfVectors, int sampleSize);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.sampler;

public class SamplingFactory {
public enum SamplerType {
RESERVOIR,
}

public static Sampler getSampler(SamplerType samplerType) {
switch (samplerType) {
case RESERVOIR:
return new ReservoirSampler();
// Add more cases for different samplers
default:
throw new IllegalArgumentException("Unsupported sampler type: " + samplerType);
}
}
}
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@

import org.opensearch.knn.KNNTestCase;

public class JNICommonsTest extends KNNTestCase {
public class JNICommonsTests extends KNNTestCase {

public void testStoreVectorData_whenVaildInputThenSuccess() {
float[][] data = new float[2][2];
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization;

import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.quantization.enums.SQTypes;
import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
import org.opensearch.knn.quantization.quantizer.Quantizer;


public class QuantizationManagerTests extends KNNTestCase {
public void testSingletonInstance() {
QuantizationManager instance1 = QuantizationManager.getInstance();
QuantizationManager instance2 = QuantizationManager.getInstance();
assertSame(instance1, instance2);
}

public void testTrain() {
QuantizationManager quantizationManager = QuantizationManager.getInstance();
float[][] vectors = {
{1.0f, 2.0f, 3.0f},
{4.0f, 5.0f, 6.0f},
{7.0f, 8.0f, 9.0f}
};

SQParams params = new SQParams(SQTypes.ONE_BIT);
TrainingRequest<float[]> originalRequest = new TrainingRequest<float[]>(params, vectors.length) {
@Override
public float[] getVector() {
return null; // Not used in this test
}
@Override
public float[] getVectorByDocId(int docId) {
return vectors[docId];
}
};
QuantizationState state = quantizationManager.train(originalRequest);

assertTrue(state instanceof OneBitScalarQuantizationState);
float[] mean = ((OneBitScalarQuantizationState) state).getMean();
assertArrayEquals(new float[]{4.0f, 5.0f, 6.0f}, mean, 0.001f);
}

public void testTrainWithFewVectors() {
QuantizationManager quantizationManager = QuantizationManager.getInstance();
float[][] vectors = {
{1.0f, 2.0f, 3.0f},
{4.0f, 5.0f, 6.0f}
};

SQParams params = new SQParams(SQTypes.ONE_BIT);
TrainingRequest<float[]> originalRequest = new TrainingRequest<float[]>(params, vectors.length) {
@Override
public float[] getVector() {
return null; // Not used in this test
}

@Override
public float[] getVectorByDocId(int docId) {
return vectors[docId];
}
};

QuantizationState state = quantizationManager.train(originalRequest);

assertTrue(state instanceof OneBitScalarQuantizationState);
float[] mean = ((OneBitScalarQuantizationState) state).getMean();
assertArrayEquals(new float[]{2.5f, 3.5f, 4.5f}, mean, 0.001f);
}


public void testGetQuantizer() {
QuantizationManager quantizationManager = QuantizationManager.getInstance();
SQParams params = new SQParams(SQTypes.ONE_BIT);

Quantizer<?, ?> quantizer = quantizationManager.getQuantizer(params);

assertTrue(quantizer instanceof Quantizer);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.enums;

import org.opensearch.knn.KNNTestCase;

public class QuantizationTypeTests extends KNNTestCase {

public void testQuantizationTypeValues() {
QuantizationType[] expectedValues = {
QuantizationType.SPACE_QUANTIZATION,
QuantizationType.VALUE_QUANTIZATION
};
assertArrayEquals(expectedValues, QuantizationType.values());
}

public void testQuantizationTypeValueOf() {
assertEquals(QuantizationType.SPACE_QUANTIZATION, QuantizationType.valueOf("SPACE_QUANTIZATION"));
assertEquals(QuantizationType.VALUE_QUANTIZATION, QuantizationType.valueOf("VALUE_QUANTIZATION"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.enums;

import org.opensearch.knn.KNNTestCase;

public class SQTypesTests extends KNNTestCase {
public void testSQTypesValues() {
SQTypes[] expectedValues = {
SQTypes.FP16,
SQTypes.INT8,
SQTypes.INT6,
SQTypes.INT4,
SQTypes.ONE_BIT,
SQTypes.TWO_BIT
};
assertArrayEquals(expectedValues, SQTypes.values());
}

public void testSQTypesValueOf() {
assertEquals(SQTypes.FP16, SQTypes.valueOf("FP16"));
assertEquals(SQTypes.INT8, SQTypes.valueOf("INT8"));
assertEquals(SQTypes.INT6, SQTypes.valueOf("INT6"));
assertEquals(SQTypes.INT4, SQTypes.valueOf("INT4"));
assertEquals(SQTypes.ONE_BIT, SQTypes.valueOf("ONE_BIT"));
assertEquals(SQTypes.TWO_BIT, SQTypes.valueOf("TWO_BIT"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.enums;

import org.opensearch.knn.KNNTestCase;

public class ValueQuantizationTypeTests extends KNNTestCase {
public void testValueQuantizationTypeValues() {
ValueQuantizationType[] expectedValues = {
ValueQuantizationType.SQ
};
assertArrayEquals(expectedValues, ValueQuantizationType.values());
}

public void testValueQuantizationTypeValueOf() {
assertEquals(ValueQuantizationType.SQ, ValueQuantizationType.valueOf("SQ"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.factory;

import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.quantization.enums.SQTypes;
import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer;
import org.opensearch.knn.quantization.quantizer.Quantizer;

public class QuantizerFactoryTests extends KNNTestCase {
public void testGetQuantizer_withSQParams() {
SQParams params = new SQParams(SQTypes.ONE_BIT);
Quantizer<?, ?> quantizer = QuantizerFactory.getQuantizer(params);
assertTrue(quantizer instanceof OneBitScalarQuantizer);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.factory;

import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.quantization.enums.SQTypes;
import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer;
import org.opensearch.knn.quantization.quantizer.Quantizer;
import org.junit.BeforeClass;

public class QuantizerRegistryTests extends KNNTestCase {
@BeforeClass
public static void setup() {
// Register the quantizer for testing
QuantizerRegistry.register(SQParams.class, SQTypes.ONE_BIT.name(), OneBitScalarQuantizer::new);
}

public void testRegisterAndGetQuantizer() {
SQParams params = new SQParams(SQTypes.ONE_BIT);
Quantizer<?, ?> quantizer = QuantizerRegistry.getQuantizer(params, SQTypes.ONE_BIT.name());
assertTrue(quantizer instanceof OneBitScalarQuantizer);
}

public void testGetQuantizer_withUnsupportedTypeIdentifier() {
SQParams params = new SQParams(SQTypes.ONE_BIT);
expectThrows( IllegalArgumentException.class, ()-> QuantizerRegistry.getQuantizer(params, "UNSUPPORTED_TYPE"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.quantizer;

import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.quantization.enums.SQTypes;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.SamplingTrainingRequest;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
import org.opensearch.knn.quantization.sampler.ReservoirSampler;

public class OneBitScalarQuantizerTests extends KNNTestCase {

public void testTrain() {
float[][] vectors = {
{1.0f, 2.0f, 3.0f},
{4.0f, 5.0f, 6.0f},
{7.0f, 8.0f, 9.0f}
};

SQParams params = new SQParams(SQTypes.ONE_BIT);
TrainingRequest<float[]> originalRequest = new TrainingRequest<float[]>(params, vectors.length) {
@Override
public float[] getVector() {
return null; // Not used in this test
}
@Override
public float[] getVectorByDocId(int docId) {
return vectors[docId];
}
};
TrainingRequest<float[]> trainingRequest = new SamplingTrainingRequest<>(
originalRequest,
new ReservoirSampler(),
vectors.length
);
OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
QuantizationState state = quantizer.train(trainingRequest);

assertTrue(state instanceof OneBitScalarQuantizationState);
float[] mean = ((OneBitScalarQuantizationState) state).getMean();
assertArrayEquals(new float[]{4.0f, 5.0f, 6.0f}, mean, 0.001f);
}

public void testQuantize_withState() {
float[] vector = {3.0f, 6.0f, 9.0f};
float[] thresholds = {4.0f, 5.0f, 6.0f};
OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(SQTypes.ONE_BIT), thresholds);

OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
QuantizationOutput<byte[]> output = quantizer.quantize(vector, state);

assertArrayEquals(new byte[]{96}, output.getQuantizedVector());
}

public void testQuantize_withoutState() {
float[] vector = {-1.0f, 0.5f, 1.5f};

OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
QuantizationOutput<byte[]> output = quantizer.quantize(vector);

assertArrayEquals(new byte[]{96}, output.getQuantizedVector());
}

public void testQuantize_withNullVector() {
OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
expectThrows( IllegalArgumentException.class, ()-> quantizer.quantize(null));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.sampler;

import org.opensearch.knn.KNNTestCase;


public class ReservoirSamplerTests extends KNNTestCase {

public void testSample() {
Sampler sampler = new ReservoirSampler();
int totalNumberOfVectors = 100;
int sampleSize = 10;

int[] samples = sampler.sample(totalNumberOfVectors, sampleSize);
assertEquals(sampleSize, samples.length);
for (int index : samples) {
assertTrue(index >= 0 && index < totalNumberOfVectors);
}
}

public void testSample_withFullSampling() {
Sampler sampler = new ReservoirSampler();
int totalNumberOfVectors = 10;
int sampleSize = 10;

int[] samples = sampler.sample(totalNumberOfVectors, sampleSize);
assertEquals(sampleSize, samples.length);
for (int index : samples) {
assertTrue(index >= 0 && index < totalNumberOfVectors);
}
}

public void testSample_withLessVectors() {
Sampler sampler = new ReservoirSampler();
int totalNumberOfVectors = 5;
int sampleSize = 10;

int[] samples = sampler.sample(totalNumberOfVectors, sampleSize);
assertEquals(totalNumberOfVectors, samples.length);
for (int index : samples) {
assertTrue(index >= 0 && index < totalNumberOfVectors);
}
}

public void testSample_withZeroVectors() {
Sampler sampler = new ReservoirSampler();
int totalNumberOfVectors = 0;
int sampleSize = 10;

int[] samples = sampler.sample(totalNumberOfVectors, sampleSize);
assertEquals(0, samples.length);
}

public void testSample_withOneVector() {
Sampler sampler = new ReservoirSampler();
int totalNumberOfVectors = 1;
int sampleSize = 10;

int[] samples = sampler.sample(totalNumberOfVectors, sampleSize);
assertEquals(1, samples.length);
assertTrue(samples[0] == 0);
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.quantization.sampler;

import org.opensearch.knn.KNNTestCase;

public class SamplingFactoryTests extends KNNTestCase {
public void testGetSampler_withReservoir() {
Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR);
assertTrue(sampler instanceof ReservoirSampler);
}

public void testGetSampler_withUnsupportedType() {
expectThrows( NullPointerException.class, ()-> SamplingFactory.getSampler(null)); // This should throw an exception
}
}

0 comments on commit 68f29f8

Please sign in to comment.