From 68f29f8351cef1c4e2581514e6f6a9f77dab4239 Mon Sep 17 00:00:00 2001 From: VIKASH TIWARI Date: Tue, 2 Jul 2024 11:22:43 -0700 Subject: [PATCH] QuantizationFramework Changes --- .../knn/index/NativeIndexCreationManager.java | 24 +++- .../opensearch/knn/index/query/KNNWeight.java | 20 ++- .../org/opensearch/knn/jni/JNIService.java | 5 + .../knn/quantization/QuantizationManager.java | 45 +++++++ .../quantization/enums/QuantizationType.java | 17 +++ .../knn/quantization/enums/SQTypes.java | 21 +++ .../enums/ValueQuantizationType.java | 17 +++ .../factory/QuantizerFactory.java | 34 +++++ .../factory/QuantizerRegistry.java | 39 ++++++ .../OneBitScalarQuantizationOutput.java | 19 +++ .../QuantizationOutput.java | 24 ++++ .../QuantizationParams.java | 27 ++++ .../models/quantizationParams/SQParams.java | 28 ++++ .../OneBitScalarQuantizationState.java | 54 ++++++++ .../quantizationState/QuantizationState.java | 47 +++++++ .../requests/SamplingTrainingRequest.java | 41 ++++++ .../models/requests/TrainingRequest.java | 37 ++++++ .../quantizer/OneBitScalarQuantizer.java | 125 ++++++++++++++++++ .../knn/quantization/quantizer/Quantizer.java | 31 +++++ .../sampler/ReservoirSampler.java | 41 ++++++ .../knn/quantization/sampler/Sampler.java | 18 +++ .../quantization/sampler/SamplingFactory.java | 28 ++++ ...ICommonsTest.java => JNICommonsTests.java} | 2 +- .../QuantizationManagerTests.java | 92 +++++++++++++ .../enums/QuantizationTypeTests.java | 30 +++++ .../knn/quantization/enums/SQTypesTests.java | 37 ++++++ .../enums/ValueQuantizationTypeTests.java | 27 ++++ .../factory/QuantizerFactoryTests.java | 26 ++++ .../factory/QuantizerRegistryTests.java | 38 ++++++ .../quantizer/OneBitScalarQuantizerTests.java | 81 ++++++++++++ .../sampler/ReservoirSamplerTests.java | 74 +++++++++++ .../sampler/SamplingFactoryTests.java | 25 ++++ 32 files changed, 1165 insertions(+), 9 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/quantization/QuantizationManager.java create mode 100644 src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java create mode 100644 src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java create mode 100644 src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java create mode 100644 src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java create mode 100644 src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/OneBitScalarQuantizationOutput.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/requests/SamplingTrainingRequest.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java create mode 100644 src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java create mode 100644 src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java create mode 100644 src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java create mode 100644 src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java create mode 100644 src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java rename src/test/java/org/opensearch/knn/jni/{JNICommonsTest.java => JNICommonsTests.java} (95%) create mode 100644 src/test/java/org/opensearch/knn/quantization/QuantizationManagerTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java diff --git a/src/main/java/org/opensearch/knn/index/NativeIndexCreationManager.java b/src/main/java/org/opensearch/knn/index/NativeIndexCreationManager.java index aa727a700..af4854090 100644 --- a/src/main/java/org/opensearch/knn/index/NativeIndexCreationManager.java +++ b/src/main/java/org/opensearch/knn/index/NativeIndexCreationManager.java @@ -21,6 +21,11 @@ import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesIterator; import org.opensearch.knn.jni.JNICommons; +import org.opensearch.knn.quantization.QuantizationManager; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.quantizer.Quantizer; import java.io.IOException; import java.util.ArrayList; @@ -56,7 +61,7 @@ private static void createNativeIndex( } private static KNNCodecUtil.Pair streamFloatVectors(final KNNVectorValues kNNVectorValues) throws IOException { - List vectorList = new ArrayList<>(); + List vectorList = new ArrayList<>(); List docIdList = new ArrayList<>(); long vectorAddress = 0; int dimension = 0; @@ -64,6 +69,9 @@ private static KNNCodecUtil.Pair streamFloatVectors(final KNNVectorValues quantizer = (Quantizer) QuantizationManager.getInstance().getQuantizer(params); + KNNVectorValuesIterator iterator = kNNVectorValues.getVectorValuesIterator(); for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) { @@ -71,9 +79,10 @@ private static KNNCodecUtil.Pair streamFloatVectors(final KNNVectorValues(); } - vectorList.add(vector); + vectorList.add(quantizedVector); docIdList.add(doc); } if (vectorList.isEmpty() == false) { - vectorAddress = JNICommons.storeVectorData(vectorAddress, vectorList.toArray(new float[][] {}), totalLiveDocs * dimension); + vectorAddress = JNICommons.storeByteVectorData(vectorAddress, vectorList.toArray(new byte[][] {}), totalLiveDocs * dimension); } // SerializationMode.COLLECTION_OF_FLOATS is not getting used. I just added it to ensure code successfully // works. @@ -105,4 +114,9 @@ private static KNNCodecUtil.Pair streamFloatVectors(final KNNVectorValues doANNSearch(final LeafReaderContext context, final B SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader()); String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString(); + QuantizationParams params = getQuantizationParams(); // Implement this method to get appropriate params + Quantizer quantizer = (Quantizer) QuantizationManager.getInstance().getQuantizer(params); + FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); if (fieldInfo == null) { @@ -272,7 +285,7 @@ private Map doANNSearch(final LeafReaderContext context, final B spaceType, knnEngine, knnQuery.getIndexName(), - FieldInfoExtractor.getIndexDescription(fieldInfo) + "B" + FieldInfoExtractor.getIndexDescription(fieldInfo) ), knnQuery.getIndexName(), modelId @@ -295,10 +308,11 @@ private Map doANNSearch(final LeafReaderContext context, final B throw new RuntimeException("Index has already been closed"); } int[] parentIds = getParentIdsArray(context); + byte[] quantizedVector = quantizer.quantize(knnQuery.getQueryVector()).getQuantizedVector(); if (knnQuery.getK() > 0) { - results = JNIService.queryIndex( + results = JNIService.queryBinaryIndex( indexAllocation.getMemoryAddress(), - knnQuery.getQueryVector(), + quantizedVector, knnQuery.getK(), knnEngine, filterIds, diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index e7689d9cc..af304fa09 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -52,6 +52,11 @@ public static void createIndex( } if (KNNEngine.FAISS == knnEngine) { + if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null) { + String indexDesc = (String) parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER); + parameters.put(KNNConstants.INDEX_DESCRIPTION_PARAMETER ,"B" + indexDesc); + + } if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) != null && parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_PREFIX)) { FaissService.createBinaryIndex(ids, vectorsAddress, dim, indexPath, parameters); diff --git a/src/main/java/org/opensearch/knn/quantization/QuantizationManager.java b/src/main/java/org/opensearch/knn/quantization/QuantizationManager.java new file mode 100644 index 000000000..10213f053 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/QuantizationManager.java @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization; + +import org.opensearch.knn.quantization.factory.QuantizerFactory; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.SamplingTrainingRequest; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import org.opensearch.knn.quantization.quantizer.Quantizer; +import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplingFactory; + +public class QuantizationManager { + private static QuantizationManager instance; + + private QuantizationManager() {} + + public static QuantizationManager getInstance() { + if (instance == null) { + instance = new QuantizationManager(); + } + return instance; + } + public QuantizationState train(TrainingRequest trainingRequest) { + Quantizer quantizer = (Quantizer) getQuantizer(trainingRequest.getParams()); + int sampleSize = quantizer.getSamplingSize(); + Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); + TrainingRequest sampledRequest = new SamplingTrainingRequest<>(trainingRequest, sampler, sampleSize); + return quantizer.train(sampledRequest); + } + public Quantizer getQuantizer(QuantizationParams params) { + return QuantizerFactory.getQuantizer(params); + } +} + diff --git a/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java new file mode 100644 index 000000000..9031ec486 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java @@ -0,0 +1,17 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.enums; + +public enum QuantizationType { + SPACE_QUANTIZATION, + VALUE_QUANTIZATION, +} diff --git a/src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java b/src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java new file mode 100644 index 000000000..a7f4f59e6 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java @@ -0,0 +1,21 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.enums; + +public enum SQTypes { + FP16, + INT8, + INT6, + INT4, + ONE_BIT, + TWO_BIT +} diff --git a/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java new file mode 100644 index 000000000..ec5cb1814 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java @@ -0,0 +1,17 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.enums; + +public enum ValueQuantizationType { + SQ +} + diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java new file mode 100644 index 000000000..b6a325973 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.factory; + +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.Quantizer; + +public class QuantizerFactory { + static { + // Register all quantizers here + QuantizerRegistry.register(SQParams.class, SQTypes.ONE_BIT.name(), OneBitScalarQuantizer::new); + } + + public static Quantizer getQuantizer(QuantizationParams params) { + if (params instanceof SQParams) { + SQParams sqParams = (SQParams) params; + return QuantizerRegistry.getQuantizer(params, sqParams.getSqType().name()); + } + // Add more cases for other quantization parameters here + throw new IllegalArgumentException("Unsupported quantization parameters: " + params.getClass().getName()); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java new file mode 100644 index 000000000..e25ea029c --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java @@ -0,0 +1,39 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.factory; + +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.quantizer.Quantizer; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Supplier; + +public class QuantizerRegistry { + private static final Map, Map>>> registry = new HashMap<>(); + + public static void register(Class paramClass, String typeIdentifier, Supplier> quantizerSupplier) { + registry.computeIfAbsent(paramClass, k -> new HashMap<>()).put(typeIdentifier, quantizerSupplier); + } + + public static Quantizer getQuantizer(QuantizationParams params, String typeIdentifier) { + Map>> typeMap = registry.get(params.getClass()); + if (typeMap == null) { + throw new IllegalArgumentException("No quantizer registered for parameters: " + params.getClass().getName()); + } + Supplier> supplier = typeMap.get(typeIdentifier); + if (supplier == null) { + throw new IllegalArgumentException("No quantizer registered for type identifier: " + typeIdentifier); + } + return supplier.get(); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/OneBitScalarQuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/OneBitScalarQuantizationOutput.java new file mode 100644 index 000000000..458cff89a --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/OneBitScalarQuantizationOutput.java @@ -0,0 +1,19 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.models.quantizationOutput; + +public class OneBitScalarQuantizationOutput extends QuantizationOutput { + + public OneBitScalarQuantizationOutput(byte[] quantizedVector) { + super(quantizedVector); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java new file mode 100644 index 000000000..38dfca50e --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java @@ -0,0 +1,24 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.models.quantizationOutput; + +public abstract class QuantizationOutput { + private final T quantizedVector; + + public QuantizationOutput(T quantizedVector) { + this.quantizedVector = quantizedVector; + } + + public T getQuantizedVector() { + return quantizedVector; + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java new file mode 100644 index 000000000..142b1a59a --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.models.quantizationParams; + +import java.io.Serializable; +import org.opensearch.knn.quantization.enums.QuantizationType; + +public abstract class QuantizationParams implements Serializable { + private QuantizationType quantizationType; + + public QuantizationParams(QuantizationType quantizationType) { + this.quantizationType = quantizationType; + } + + public QuantizationType getQuantizationType() { + return quantizationType; + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java new file mode 100644 index 000000000..e82e8f178 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.models.quantizationParams; + +import org.opensearch.knn.quantization.enums.QuantizationType; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; + +public class SQParams extends QuantizationParams { + private SQTypes sqType; + + public SQParams(SQTypes sqType) { + super(QuantizationType.VALUE_QUANTIZATION); + this.sqType = sqType; + } + public SQTypes getSqType() { + return sqType; + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java new file mode 100644 index 000000000..8c2963e2b --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.models.quantizationState; + + +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectOutputStream; +import java.io.ByteArrayInputStream; +import java.io.ObjectInputStream; + +public class OneBitScalarQuantizationState extends QuantizationState { + private float[] mean; + + public OneBitScalarQuantizationState(SQParams quantizationParams, float[] floatArray) { + super(quantizationParams); + this.mean = floatArray; + } + + public float[] getMean() { + return mean; + } + + @Override + public byte[] toByteArray() throws IOException { + byte[] parentBytes = super.toByteArray(); + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream out = new ObjectOutputStream(bos); + out.write(parentBytes); + out.writeObject(mean); + out.flush(); + return bos.toByteArray(); + } + + public static OneBitScalarQuantizationState fromByteArray(byte[] bytes) throws IOException, ClassNotFoundException { + ByteArrayInputStream bis = new ByteArrayInputStream(bytes); + ObjectInputStream in = new ObjectInputStream(bis); + QuantizationState parentState = (QuantizationState) in.readObject(); + float[] floatArray = (float[]) in.readObject(); + return new OneBitScalarQuantizationState((SQParams) parentState.getQuantizationParams(), floatArray); + } +} + diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java new file mode 100644 index 000000000..ba7c2bb48 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.io.ByteArrayInputStream; +import java.io.ObjectInputStream; + +public abstract class QuantizationState implements Serializable { + private QuantizationParams quantizationParams; + + public QuantizationState(QuantizationParams quantizationParams) { + this.quantizationParams = quantizationParams; + } + + public QuantizationParams getQuantizationParams() { + return quantizationParams; + } + + public byte[] toByteArray() throws IOException { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream out = new ObjectOutputStream(bos); + out.writeObject(this); + out.flush(); + return bos.toByteArray(); + } + + public static QuantizationState fromByteArray(byte[] bytes) throws IOException, ClassNotFoundException { + ByteArrayInputStream bis = new ByteArrayInputStream(bytes); + ObjectInputStream in = new ObjectInputStream(bis); + return (QuantizationState) in.readObject(); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/requests/SamplingTrainingRequest.java b/src/main/java/org/opensearch/knn/quantization/models/requests/SamplingTrainingRequest.java new file mode 100644 index 000000000..ee0de84fb --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/requests/SamplingTrainingRequest.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.models.requests; + +import org.opensearch.knn.quantization.sampler.Sampler; + +import java.util.List; + +public class SamplingTrainingRequest extends TrainingRequest { + private TrainingRequest originalRequest; + private int[] sampledIndices; + + public SamplingTrainingRequest(TrainingRequest originalRequest, Sampler sampler, int sampleSize) { + super(originalRequest.getParams(), originalRequest.getTotalNumberOfVectors()); + this.originalRequest = originalRequest; + this.sampledIndices = sampler.sample(originalRequest.getTotalNumberOfVectors(), sampleSize); + } + + @Override + public T getVector() { + return originalRequest.getVector(); + } + + @Override + public T getVectorByDocId(int docId) { + return originalRequest.getVectorByDocId(docId); + } + + public int[] getSampledIndices() { + return sampledIndices; + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java new file mode 100644 index 000000000..a2830e6d4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.models.requests; + +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; + +public abstract class TrainingRequest { + private QuantizationParams params; + private int totalNumberOfVectors; + + public TrainingRequest(QuantizationParams params, int totalNumberOfVectors) { + this.params = params; + this.totalNumberOfVectors = totalNumberOfVectors; + } + + public QuantizationParams getParams() { + return params; + } + + public int getTotalNumberOfVectors() { + return totalNumberOfVectors; + } + + public abstract T getVector(); + + public abstract T getVectorByDocId(int docId); +} + diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java new file mode 100644 index 000000000..9f8d04cc9 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java @@ -0,0 +1,125 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.quantization.models.quantizationOutput.OneBitScalarQuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.SamplingTrainingRequest; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; + +public class OneBitScalarQuantizer implements Quantizer { + private static final int SAMPLING_SIZE = 25000; + + @Override + public int getSamplingSize() { + return SAMPLING_SIZE; + } + + @Override + public QuantizationState train(TrainingRequest trainingRequest) { + if (!(trainingRequest instanceof SamplingTrainingRequest)) { + throw new IllegalArgumentException("Training request must be of type SamplingTrainingRequest."); + } + + SamplingTrainingRequest samplingRequest = (SamplingTrainingRequest) trainingRequest; + int[] sampledIndices = samplingRequest.getSampledIndices(); + + if (sampledIndices == null || sampledIndices.length == 0) { + throw new IllegalArgumentException("Sampled indices must not be null or empty."); + } + + int totalSamples = sampledIndices.length; + float[] sum = null; + + // Calculate the sum for each dimension based on sampled indices + for (int i = 0; i < totalSamples; i++) { + float[] vector = samplingRequest.getVectorByDocId(sampledIndices[i]); + if (vector == null) { + throw new IllegalArgumentException("Vector at sampled index " + sampledIndices[i] + " is null."); + } + if (sum == null) { + sum = new float[vector.length]; + } else if (sum.length != vector.length) { + throw new IllegalArgumentException("All vectors must have the same dimension."); + } + for (int j = 0; j < vector.length; j++) { + sum[j] += vector[j]; + } + } + if (sum == null) { + throw new IllegalStateException("Sum array should not be null after processing vectors."); + } + // Calculate the mean for each dimension + float[] mean = new float[sum.length]; + for (int j = 0; j < sum.length; j++) { + mean[j] = sum[j] / totalSamples; + } + SQParams params = (SQParams) trainingRequest.getParams(); + if (params == null) { + throw new IllegalArgumentException("Quantization parameters must not be null."); + } + return new OneBitScalarQuantizationState(params, mean); + } + + @Override + public QuantizationOutput quantize(float[] vector, QuantizationState state) { + if (vector == null) { + throw new IllegalArgumentException("Vector to quantize must not be null."); + } + if (!(state instanceof OneBitScalarQuantizationState)) { + throw new IllegalArgumentException("Quantization state must be of type OneBitScalarQuantizationState."); + } + OneBitScalarQuantizationState binaryState = (OneBitScalarQuantizationState) state; + float[] thresholds = binaryState.getMean(); + if (thresholds == null || thresholds.length != vector.length) { + throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); + } + byte[] quantizedVector = new byte[vector.length]; + for (int i = 0; i < vector.length; i++) { + quantizedVector[i] = (byte) (vector[i] > thresholds[i] ? 1 : 0); + } + return new OneBitScalarQuantizationOutput(packBitsFromBitArray(quantizedVector)); + } + + @Override + public QuantizationOutput quantize(float[] vector) { + if (vector == null) { + throw new IllegalArgumentException("Vector to quantize must not be null."); + } + byte[] quantizedVector = new byte[vector.length]; + for (int i = 0; i < vector.length; i++) { + quantizedVector[i] = (byte) (vector[i] > 0 ? 1 : 0); + } + return new OneBitScalarQuantizationOutput(packBitsFromBitArray(quantizedVector)); + } + + private byte[] packBitsFromBitArray(byte[] bitArray) { + int bitLength = bitArray.length; + int byteLength = (bitLength + 7) / 8; + byte[] packedArray = new byte[byteLength]; + + for (int i = 0; i < bitLength; i++) { + if (bitArray[i] != 0 && bitArray[i] != 1) { + throw new IllegalArgumentException("Array elements must be 0 or 1"); + } + int byteIndex = i / 8; + int bitIndex = 7 - (i % 8); + packedArray[byteIndex] |= (bitArray[i] << bitIndex); + } + + return packedArray; + } +} + diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java new file mode 100644 index 000000000..35fa19cf7 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java @@ -0,0 +1,31 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +public interface Quantizer { + int getSamplingSize(); + + default QuantizationState train(TrainingRequest trainingRequest) { + throw new UnsupportedOperationException("Train method is not supported by this quantizer."); + } + + default QuantizationOutput quantize(T vector, QuantizationState state) { + throw new UnsupportedOperationException("Quantize method with state is not supported by this quantizer."); + } + + default QuantizationOutput quantize(T vector) { + throw new UnsupportedOperationException("Quantize method without state is not supported by this quantizer."); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java new file mode 100644 index 000000000..15fbe6e31 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.sampler; + +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class ReservoirSampler implements Sampler { + private final Random random = new Random(); + + @Override + public int[] sample(int totalNumberOfVectors, int sampleSize) { + if (totalNumberOfVectors <= sampleSize) { + return IntStream.range(0, totalNumberOfVectors).toArray(); + } + return reservoirSampleIndices(totalNumberOfVectors, sampleSize); + } + private int[] reservoirSampleIndices(int numVectors, int sampleSize) { + int[] indices = IntStream.range(0, sampleSize).toArray(); + for (int i = sampleSize; i < numVectors; i++) { + int j = random.nextInt(i + 1); + if (j < sampleSize) { + indices[j] = i; + } + } + Arrays.sort(indices); + return indices; + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java new file mode 100644 index 000000000..dd935a600 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java @@ -0,0 +1,18 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.sampler; + +import java.util.List; + +public interface Sampler { + int[] sample(int totalNumberOfVectors, int sampleSize); +} diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java new file mode 100644 index 000000000..1b39d9846 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java @@ -0,0 +1,28 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.sampler; + +public class SamplingFactory { + public enum SamplerType { + RESERVOIR, + } + + public static Sampler getSampler(SamplerType samplerType) { + switch (samplerType) { + case RESERVOIR: + return new ReservoirSampler(); + // Add more cases for different samplers + default: + throw new IllegalArgumentException("Unsupported sampler type: " + samplerType); + } + } +} diff --git a/src/test/java/org/opensearch/knn/jni/JNICommonsTest.java b/src/test/java/org/opensearch/knn/jni/JNICommonsTests.java similarity index 95% rename from src/test/java/org/opensearch/knn/jni/JNICommonsTest.java rename to src/test/java/org/opensearch/knn/jni/JNICommonsTests.java index bf27458b0..1ea86ef96 100644 --- a/src/test/java/org/opensearch/knn/jni/JNICommonsTest.java +++ b/src/test/java/org/opensearch/knn/jni/JNICommonsTests.java @@ -13,7 +13,7 @@ import org.opensearch.knn.KNNTestCase; -public class JNICommonsTest extends KNNTestCase { +public class JNICommonsTests extends KNNTestCase { public void testStoreVectorData_whenVaildInputThenSuccess() { float[][] data = new float[2][2]; diff --git a/src/test/java/org/opensearch/knn/quantization/QuantizationManagerTests.java b/src/test/java/org/opensearch/knn/quantization/QuantizationManagerTests.java new file mode 100644 index 000000000..7a50c1599 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/QuantizationManagerTests.java @@ -0,0 +1,92 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import org.opensearch.knn.quantization.quantizer.Quantizer; + + +public class QuantizationManagerTests extends KNNTestCase { + public void testSingletonInstance() { + QuantizationManager instance1 = QuantizationManager.getInstance(); + QuantizationManager instance2 = QuantizationManager.getInstance(); + assertSame(instance1, instance2); + } + + public void testTrain() { + QuantizationManager quantizationManager = QuantizationManager.getInstance(); + float[][] vectors = { + {1.0f, 2.0f, 3.0f}, + {4.0f, 5.0f, 6.0f}, + {7.0f, 8.0f, 9.0f} + }; + + SQParams params = new SQParams(SQTypes.ONE_BIT); + TrainingRequest originalRequest = new TrainingRequest(params, vectors.length) { + @Override + public float[] getVector() { + return null; // Not used in this test + } + @Override + public float[] getVectorByDocId(int docId) { + return vectors[docId]; + } + }; + QuantizationState state = quantizationManager.train(originalRequest); + + assertTrue(state instanceof OneBitScalarQuantizationState); + float[] mean = ((OneBitScalarQuantizationState) state).getMean(); + assertArrayEquals(new float[]{4.0f, 5.0f, 6.0f}, mean, 0.001f); + } + + public void testTrainWithFewVectors() { + QuantizationManager quantizationManager = QuantizationManager.getInstance(); + float[][] vectors = { + {1.0f, 2.0f, 3.0f}, + {4.0f, 5.0f, 6.0f} + }; + + SQParams params = new SQParams(SQTypes.ONE_BIT); + TrainingRequest originalRequest = new TrainingRequest(params, vectors.length) { + @Override + public float[] getVector() { + return null; // Not used in this test + } + + @Override + public float[] getVectorByDocId(int docId) { + return vectors[docId]; + } + }; + + QuantizationState state = quantizationManager.train(originalRequest); + + assertTrue(state instanceof OneBitScalarQuantizationState); + float[] mean = ((OneBitScalarQuantizationState) state).getMean(); + assertArrayEquals(new float[]{2.5f, 3.5f, 4.5f}, mean, 0.001f); + } + + + public void testGetQuantizer() { + QuantizationManager quantizationManager = QuantizationManager.getInstance(); + SQParams params = new SQParams(SQTypes.ONE_BIT); + + Quantizer quantizer = quantizationManager.getQuantizer(params); + + assertTrue(quantizer instanceof Quantizer); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java new file mode 100644 index 000000000..aa4251bfc --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java @@ -0,0 +1,30 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.enums; + +import org.opensearch.knn.KNNTestCase; + +public class QuantizationTypeTests extends KNNTestCase { + + public void testQuantizationTypeValues() { + QuantizationType[] expectedValues = { + QuantizationType.SPACE_QUANTIZATION, + QuantizationType.VALUE_QUANTIZATION + }; + assertArrayEquals(expectedValues, QuantizationType.values()); + } + + public void testQuantizationTypeValueOf() { + assertEquals(QuantizationType.SPACE_QUANTIZATION, QuantizationType.valueOf("SPACE_QUANTIZATION")); + assertEquals(QuantizationType.VALUE_QUANTIZATION, QuantizationType.valueOf("VALUE_QUANTIZATION")); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java b/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java new file mode 100644 index 000000000..5f9aba958 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.enums; + +import org.opensearch.knn.KNNTestCase; + +public class SQTypesTests extends KNNTestCase { + public void testSQTypesValues() { + SQTypes[] expectedValues = { + SQTypes.FP16, + SQTypes.INT8, + SQTypes.INT6, + SQTypes.INT4, + SQTypes.ONE_BIT, + SQTypes.TWO_BIT + }; + assertArrayEquals(expectedValues, SQTypes.values()); + } + + public void testSQTypesValueOf() { + assertEquals(SQTypes.FP16, SQTypes.valueOf("FP16")); + assertEquals(SQTypes.INT8, SQTypes.valueOf("INT8")); + assertEquals(SQTypes.INT6, SQTypes.valueOf("INT6")); + assertEquals(SQTypes.INT4, SQTypes.valueOf("INT4")); + assertEquals(SQTypes.ONE_BIT, SQTypes.valueOf("ONE_BIT")); + assertEquals(SQTypes.TWO_BIT, SQTypes.valueOf("TWO_BIT")); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java new file mode 100644 index 000000000..d5e74ec6a --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.enums; + +import org.opensearch.knn.KNNTestCase; + +public class ValueQuantizationTypeTests extends KNNTestCase { + public void testValueQuantizationTypeValues() { + ValueQuantizationType[] expectedValues = { + ValueQuantizationType.SQ + }; + assertArrayEquals(expectedValues, ValueQuantizationType.values()); + } + + public void testValueQuantizationTypeValueOf() { + assertEquals(ValueQuantizationType.SQ, ValueQuantizationType.valueOf("SQ")); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java new file mode 100644 index 000000000..29ce32066 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.factory; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.Quantizer; + +public class QuantizerFactoryTests extends KNNTestCase { + public void testGetQuantizer_withSQParams() { + SQParams params = new SQParams(SQTypes.ONE_BIT); + Quantizer quantizer = QuantizerFactory.getQuantizer(params); + assertTrue(quantizer instanceof OneBitScalarQuantizer); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java new file mode 100644 index 000000000..d77255a08 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java @@ -0,0 +1,38 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.factory; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.Quantizer; +import org.junit.BeforeClass; + +public class QuantizerRegistryTests extends KNNTestCase { + @BeforeClass + public static void setup() { + // Register the quantizer for testing + QuantizerRegistry.register(SQParams.class, SQTypes.ONE_BIT.name(), OneBitScalarQuantizer::new); + } + + public void testRegisterAndGetQuantizer() { + SQParams params = new SQParams(SQTypes.ONE_BIT); + Quantizer quantizer = QuantizerRegistry.getQuantizer(params, SQTypes.ONE_BIT.name()); + assertTrue(quantizer instanceof OneBitScalarQuantizer); + } + + public void testGetQuantizer_withUnsupportedTypeIdentifier() { + SQParams params = new SQParams(SQTypes.ONE_BIT); + expectThrows( IllegalArgumentException.class, ()-> QuantizerRegistry.getQuantizer(params, "UNSUPPORTED_TYPE")); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java new file mode 100644 index 000000000..43c960f33 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java @@ -0,0 +1,81 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.SamplingTrainingRequest; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import org.opensearch.knn.quantization.sampler.ReservoirSampler; + +public class OneBitScalarQuantizerTests extends KNNTestCase { + + public void testTrain() { + float[][] vectors = { + {1.0f, 2.0f, 3.0f}, + {4.0f, 5.0f, 6.0f}, + {7.0f, 8.0f, 9.0f} + }; + + SQParams params = new SQParams(SQTypes.ONE_BIT); + TrainingRequest originalRequest = new TrainingRequest(params, vectors.length) { + @Override + public float[] getVector() { + return null; // Not used in this test + } + @Override + public float[] getVectorByDocId(int docId) { + return vectors[docId]; + } + }; + TrainingRequest trainingRequest = new SamplingTrainingRequest<>( + originalRequest, + new ReservoirSampler(), + vectors.length + ); + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + QuantizationState state = quantizer.train(trainingRequest); + + assertTrue(state instanceof OneBitScalarQuantizationState); + float[] mean = ((OneBitScalarQuantizationState) state).getMean(); + assertArrayEquals(new float[]{4.0f, 5.0f, 6.0f}, mean, 0.001f); + } + + public void testQuantize_withState() { + float[] vector = {3.0f, 6.0f, 9.0f}; + float[] thresholds = {4.0f, 5.0f, 6.0f}; + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(SQTypes.ONE_BIT), thresholds); + + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + QuantizationOutput output = quantizer.quantize(vector, state); + + assertArrayEquals(new byte[]{96}, output.getQuantizedVector()); + } + + public void testQuantize_withoutState() { + float[] vector = {-1.0f, 0.5f, 1.5f}; + + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + QuantizationOutput output = quantizer.quantize(vector); + + assertArrayEquals(new byte[]{96}, output.getQuantizedVector()); + } + + public void testQuantize_withNullVector() { + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + expectThrows( IllegalArgumentException.class, ()-> quantizer.quantize(null)); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java new file mode 100644 index 000000000..b0a290941 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java @@ -0,0 +1,74 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.sampler; + +import org.opensearch.knn.KNNTestCase; + + +public class ReservoirSamplerTests extends KNNTestCase { + + public void testSample() { + Sampler sampler = new ReservoirSampler(); + int totalNumberOfVectors = 100; + int sampleSize = 10; + + int[] samples = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals(sampleSize, samples.length); + for (int index : samples) { + assertTrue(index >= 0 && index < totalNumberOfVectors); + } + } + + public void testSample_withFullSampling() { + Sampler sampler = new ReservoirSampler(); + int totalNumberOfVectors = 10; + int sampleSize = 10; + + int[] samples = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals(sampleSize, samples.length); + for (int index : samples) { + assertTrue(index >= 0 && index < totalNumberOfVectors); + } + } + + public void testSample_withLessVectors() { + Sampler sampler = new ReservoirSampler(); + int totalNumberOfVectors = 5; + int sampleSize = 10; + + int[] samples = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals(totalNumberOfVectors, samples.length); + for (int index : samples) { + assertTrue(index >= 0 && index < totalNumberOfVectors); + } + } + + public void testSample_withZeroVectors() { + Sampler sampler = new ReservoirSampler(); + int totalNumberOfVectors = 0; + int sampleSize = 10; + + int[] samples = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals(0, samples.length); + } + + public void testSample_withOneVector() { + Sampler sampler = new ReservoirSampler(); + int totalNumberOfVectors = 1; + int sampleSize = 10; + + int[] samples = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals(1, samples.length); + assertTrue(samples[0] == 0); + } +} + diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java new file mode 100644 index 000000000..f767710b9 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.quantization.sampler; + +import org.opensearch.knn.KNNTestCase; + +public class SamplingFactoryTests extends KNNTestCase { + public void testGetSampler_withReservoir() { + Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); + assertTrue(sampler instanceof ReservoirSampler); + } + + public void testGetSampler_withUnsupportedType() { + expectThrows( NullPointerException.class, ()-> SamplingFactory.getSampler(null)); // This should throw an exception + } +}