From bcdd4302674ec810ef6e6bf57fd37f984761c7bc Mon Sep 17 00:00:00 2001 From: VIKASH TIWARI Date: Fri, 2 Aug 2024 22:00:13 -0700 Subject: [PATCH] Quantization Framework Implementation with 1bit and MultiBit Binary Quantizer --- CHANGELOG.md | 3 +- .../quantization/enums/QuantizationType.java | 34 ++++ .../knn/quantization/enums/SQTypes.java | 60 ++++++ .../enums/ValueQuantizationType.java | 20 ++ .../factory/QuantizerFactory.java | 54 ++++++ .../factory/QuantizerRegistrar.java | 54 ++++++ .../factory/QuantizerRegistry.java | 79 ++++++++ .../BinaryQuantizationOutput.java | 31 ++++ .../QuantizationOutput.java | 22 +++ .../QuantizationParams.java | 39 ++++ .../models/quantizationParams/SQParams.java | 83 +++++++++ .../DefaultQuantizationState.java | 67 +++++++ .../MultiBitScalarQuantizationState.java | 60 ++++++ .../OneBitScalarQuantizationState.java | 60 ++++++ .../quantizationState/QuantizationState.java | 33 ++++ .../models/requests/TrainingRequest.java | 74 ++++++++ .../quantizer/MultiBitScalarQuantizer.java | 175 ++++++++++++++++++ .../quantizer/OneBitScalarQuantizer.java | 128 +++++++++++++ .../knn/quantization/quantizer/Quantizer.java | 40 ++++ .../sampler/ReservoirSampler.java | 85 +++++++++ .../knn/quantization/sampler/Sampler.java | 10 + .../quantization/sampler/SamplingFactory.java | 46 +++++ .../quantization/util/BitPackingUtils.java | 60 ++++++ .../util/QuantizationStateSerializer.java | 109 +++++++++++ .../quantization/util/QuantizerHelper.java | 112 +++++++++++ .../enums/QuantizationTypeTests.java | 24 +++ .../knn/quantization/enums/SQTypesTests.java | 36 ++++ .../enums/ValueQuantizationTypeTests.java | 21 +++ .../factory/QuantizerFactoryTests.java | 99 ++++++++++ .../factory/QuantizerRegistryTests.java | 55 ++++++ .../QuantizationStateSerializerTests.java | 44 +++++ .../QuantizationStateTests.java | 64 +++++++ .../MultiBitScalarQuantizerTests.java | 137 ++++++++++++++ .../quantizer/OneBitScalarQuantizerTests.java | 137 ++++++++++++++ 
.../sampler/ReservoirSamplerTests.java | 103 +++++++++++ .../sampler/SamplingFactoryTests.java | 19 ++ .../util/BitPackingUtilsTests.java | 86 +++++++++ 37 files changed, 2362 insertions(+), 1 deletion(-) create mode 100644 src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java create mode 100644 src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java create mode 100644 src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java create mode 100644 src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java create mode 100644 src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java create mode 100644 src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java create mode 100644 src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java create mode 100644 
src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java create mode 100644 src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java create mode 100644 src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java create mode 100644 src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java create mode 100644 src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java create mode 100644 src/main/java/org/opensearch/knn/quantization/util/BitPackingUtils.java create mode 100644 src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java create mode 100644 src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java create mode 100644 src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java diff --git a/CHANGELOG.md 
b/CHANGELOG.md index 76eeb1e447..dcb714a58d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,4 +27,5 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Clean up parsing for query [#1824](https://github.com/opensearch-project/k-NN/pull/1824) * Refactor engine package structure [#1913](https://github.com/opensearch-project/k-NN/pull/1913) * Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920) -* Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925) \ No newline at end of file +* Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925) +* Quantization Framework For Disk Optimized Vector Search and Implementation of Binary 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889) diff --git a/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java new file mode 100644 index 0000000000..254fb1c423 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.enums; + +/** + * The QuantizationType enum represents the different types of quantization + * that can be applied in the KNN. + * + * + */ +public enum QuantizationType { + /** + * Represents space quantization, typically involving dimensionality reduction + * or space partitioning techniques. + */ + SPACE_QUANTIZATION, + + /** + * Represents value quantization, typically involving the conversion of continuous + * values into discrete ones. 
+ */ + VALUE_QUANTIZATION, +} diff --git a/src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java b/src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java new file mode 100644 index 0000000000..3fdd7f25c2 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/enums/SQTypes.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.enums; + +/** + * The SQTypes enum defines the various scalar quantization types that can be used + * in the KNN for vector quantization. + * Each type corresponds to a different bit-width representation of the quantized values. + */ +public enum SQTypes { + /** + * FP16 quantization uses 16-bit floating-point representation. + */ + FP16, + + /** + * FP8 quantization uses 8-bit floating-point representation. + */ + FP8, + + /** + * INT8 quantization uses 8-bit integer representation. + */ + INT8, + + /** + * INT6 quantization uses 6-bit integer representation. + */ + INT6, + + /** + * INT4 quantization uses 4-bit integer representation. + */ + INT4, + + /** + * ONE_BIT quantization uses a single bit per coordinate. + */ + ONE_BIT, + + /** + * TWO_BIT quantization uses two bits per coordinate. + */ + TWO_BIT, + + /** + * FOUR_BIT quantization uses four bits per coordinate. + */ + FOUR_BIT, + + /** + * UNSUPPORTED_TYPE is used to denote quantization types that are not supported. + * This can be used as a placeholder or default value. 
+ */ + UNSUPPORTED_TYPE +} + diff --git a/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java new file mode 100644 index 0000000000..ffc2ee29b0 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.enums; + +/** + * The ValueQuantizationType enum defines the types of value quantization techniques + * that can be applied in the KNN. + */ +public enum ValueQuantizationType { + /** + * SQ (Scalar Quantization) represents a method where each coordinate of the vector is quantized + * independently. + */ + SQ +} + + diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java new file mode 100644 index 0000000000..c16fdae8a1 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.factory; + +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.quantizer.Quantizer; + +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * The QuantizerFactory class is responsible for creating instances of {@link Quantizer} + * based on the provided {@link QuantizationParams}. It uses a registry to look up the + * appropriate quantizer implementation for the given quantization parameters. 
+ */ +public final class QuantizerFactory { + private static final AtomicBoolean isRegistered = new AtomicBoolean(false); + + // Private constructor to prevent instantiation + private QuantizerFactory() {} + + /** + * Ensures that default quantizers are registered. + */ + private static void ensureRegistered() { + if (!isRegistered.get()) { + synchronized (QuantizerFactory.class) { + if (!isRegistered.get()) { + QuantizerRegistrar.registerDefaultQuantizers(); + isRegistered.set(true); + } + } + } + } + + /** + * Retrieves a quantizer instance based on the provided quantization parameters. + * + * @param params the quantization parameters used to determine the appropriate quantizer + * @param

<P> the type of quantization parameters, extending {@link QuantizationParams} + * @param the type of the quantized output + * @return an instance of {@link Quantizer} corresponding to the provided parameters + */ + public static

Quantizer getQuantizer(final P params) { + if (params == null) { + throw new IllegalArgumentException("Quantization parameters must not be null."); + } + // Lazy Registration instead of static block as class level; + ensureRegistered(); + return QuantizerRegistry.getQuantizer(params); + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java new file mode 100644 index 0000000000..d932686f12 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.factory; + +import org.opensearch.knn.quantization.enums.QuantizationType; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; + +/** + * The QuantizerRegistrar class is responsible for registering default quantizers. + * This class ensures that the registration happens only once in a thread-safe manner. + */ +final class QuantizerRegistrar { + + // Private constructor to prevent instantiation + private QuantizerRegistrar() { + } + + /** + * Registers default quantizers if not already registered. + *

<p> + * This method is synchronized to ensure that registration occurs only once, + * even in a multi-threaded environment. + * </p>

+ */ + public static synchronized void registerDefaultQuantizers() { + // Register OneBitScalarQuantizer for SQParams with VALUE_QUANTIZATION and SQTypes.ONE_BIT + QuantizerRegistry.register( + SQParams.class, + QuantizationType.VALUE_QUANTIZATION, + SQTypes.ONE_BIT, + OneBitScalarQuantizer::new + ); + // Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION with bit per co-ordinate = 2 + QuantizerRegistry.register( + SQParams.class, + QuantizationType.VALUE_QUANTIZATION, + SQTypes.TWO_BIT, + () -> new MultiBitScalarQuantizer(2) + ); + // Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION with bit per co-ordinate = 4 + QuantizerRegistry.register( + SQParams.class, + QuantizationType.VALUE_QUANTIZATION, + SQTypes.FOUR_BIT, + () -> new MultiBitScalarQuantizer(4) + ); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java new file mode 100644 index 0000000000..41c932bee4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.factory; + +import org.opensearch.knn.quantization.enums.QuantizationType; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.quantizer.Quantizer; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Supplier; + +/** + * The QuantizerRegistry class is responsible for managing the registration and retrieval + * of quantizer instances. Quantizers are registered with specific quantization parameters + * and type identifiers, allowing for efficient lookup and instantiation. 
+ */ +final class QuantizerRegistry { + + // Private constructor to prevent instantiation + private QuantizerRegistry() {} + + //ConcurrentHashMap for thread-safe access + private static final Map>> registry = new ConcurrentHashMap<>(); + + /** + * Registers a quantizer with the registry. + * + * @param paramClass the class of the quantization parameters + * @param quantizationType the quantization type (e.g., VALUE_QUANTIZATION) + * @param sqType the specific quantization subtype (e.g., ONE_BIT, TWO_BIT) + * @param quantizerSupplier a supplier that provides instances of the quantizer + * @param

<P> the type of quantization parameters + */ + public static

void register( final Class

paramClass, + final QuantizationType quantizationType, + final SQTypes sqType, + final Supplier> quantizerSupplier) { + String identifier = createIdentifier(quantizationType, sqType); + // Ensure that the quantizer for this identifier is registered only once + registry.computeIfAbsent(identifier, key -> quantizerSupplier); + } + + /** + * Retrieves a quantizer instance based on the provided quantization parameters. + * + * @param params the quantization parameters used to determine the appropriate quantizer + * @param

<P> the type of quantization parameters + * @param the type of the quantized output + * @return an instance of {@link Quantizer} corresponding to the provided parameters + * @throws IllegalArgumentException if no quantizer is registered for the given parameters + */ + public static

Quantizer getQuantizer(final P params) { + String identifier = params.getTypeIdentifier(); + Supplier> supplier = registry.get(identifier); + if (supplier == null) { + throw new IllegalArgumentException("No quantizer registered for type identifier: " + identifier + + ". Available quantizers: " + registry.keySet()); + } + @SuppressWarnings("unchecked") + Quantizer quantizer = (Quantizer) supplier.get(); + return quantizer; + } + + /** + * Creates a unique identifier for the quantizer based on the quantization type and specific quantization subtype. + * + * @param quantizationType the quantization type + * @param sqType the specific quantization subtype + * @return a string identifier + */ + private static String createIdentifier(final QuantizationType quantizationType, final SQTypes sqType) { + return quantizationType.name() + "_" + sqType.name(); + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java new file mode 100644 index 0000000000..f64a2f0eef --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationOutput; + +/** + * The BinaryQuantizationOutput class represents the output of a quantization process in binary format. + * It implements the QuantizationOutput interface to handle byte arrays specifically. + */ +public class BinaryQuantizationOutput implements QuantizationOutput { + private final byte[] quantizedVector; + + /** + * Constructs a BinaryQuantizationOutput instance with the specified quantized vector. + * + * @param quantizedVector the quantized vector represented as a byte array. 
+ */ + public BinaryQuantizationOutput(final byte[] quantizedVector) { + if (quantizedVector == null) { + throw new IllegalArgumentException("Quantized vector cannot be null"); + } + this.quantizedVector = quantizedVector; + } + + @Override + public byte[] getQuantizedVector() { + return quantizedVector; + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java new file mode 100644 index 0000000000..d2fc7dce06 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationOutput; + +/** + * The QuantizationOutput interface defines the contract for quantization output data. + * + * @param The type of the quantized data. + */ +public interface QuantizationOutput { + /** + * Returns the quantized vector. + * + * @return the quantized data. + */ + T getQuantizedVector(); +} + + diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java new file mode 100644 index 0000000000..db4677985f --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationParams; + +import org.opensearch.knn.quantization.enums.QuantizationType; + +import java.io.Serializable; + +/** + * Interface for quantization parameters. + * This interface defines the basic contract for all quantization parameter types. 
+ * It provides methods to retrieve the quantization type and a unique type identifier. + * Implementations of this interface are expected to provide specific configurations + * for various quantization strategies. + */ +public interface QuantizationParams extends Serializable{ + + /** + * Gets the quantization type associated with the parameters. + * The quantization type defines the overall strategy or method used + * for quantization, such as VALUE_QUANTIZATION or SPACE_QUANTIZATION. + * + * @return the {@link QuantizationType} indicating the quantization method. + */ + QuantizationType getQuantizationType(); + + /** + * Provides a unique identifier for the quantization parameters. + * This identifier is typically a combination of the quantization type + * and additional specifics, and it serves to distinguish between different + * configurations or modes of quantization. + * + * @return a string representing the unique type identifier. + */ + String getTypeIdentifier(); +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java new file mode 100644 index 0000000000..22553d70aa --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationParams; + +import org.opensearch.knn.quantization.enums.QuantizationType; +import org.opensearch.knn.quantization.enums.SQTypes; + +import java.util.Objects; + +/** + * The SQParams class represents the parameters specific to scalar quantization (SQ). + * This class implements the QuantizationParams interface and includes the type of scalar quantization. 
+ */ +public class SQParams implements QuantizationParams { + private final SQTypes sqType; + + /** + * Constructs an SQParams instance with the specified scalar quantization type. + * + * @param sqType The specific type of scalar quantization (e.g., ONE_BIT, TWO_BIT, FOUR_BIT). + */ + public SQParams(final SQTypes sqType) { + this.sqType = sqType; + } + + /** + * Returns the quantization type associated with these parameters. + * + * @return The quantization type, always VALUE_QUANTIZATION for SQParams. + */ + @Override + public QuantizationType getQuantizationType() { + return QuantizationType.VALUE_QUANTIZATION; + } + + /** + * Returns the scalar quantization type. + * + * @return The specific scalar quantization type. + */ + public SQTypes getSqType() { + return sqType; + } + + /** + * Provides a unique type identifier for the SQParams, combining the quantization type and SQ type. + * This identifier is useful for distinguishing between different configurations of scalar quantization parameters. + * + * @return A string representing the unique type identifier. + */ + @Override + public String getTypeIdentifier() { + return getQuantizationType().name() + "_" + sqType.name(); + } + + /** + * Compares this object to the specified object. The result is true if and only if the argument is not null and is + * an SQParams object that represents the same scalar quantization type. + * + * @param o The object to compare this SQParams against. + * @return true if the given object represents an SQParams equivalent to this instance, false otherwise. + */ + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SQParams sqParams = (SQParams) o; + return sqType == sqParams.sqType; + } + + /** + * Returns a hash code value for this SQParams instance. + * + * @return A hash code value for this SQParams instance. 
+ */ + @Override + public int hashCode() { + return Objects.hash(sqType); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java new file mode 100644 index 0000000000..2599ece9aa --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.util.QuantizationStateSerializer; + +import java.io.IOException; + +/** + * DefaultQuantizationState is used as a fallback state when no training is required or if training fails. + * It can be utilized by any quantizer to represent a default state. + */ +public class DefaultQuantizationState implements QuantizationState { + + private final QuantizationParams params; + + /** + * Constructs a DefaultQuantizationState with the given quantization parameters. + * + * @param params the quantization parameters. + */ + public DefaultQuantizationState(final QuantizationParams params) { + this.params = params; + } + + /** + * Returns the quantization parameters associated with this state. + * + * @return the quantization parameters. + */ + @Override + public QuantizationParams getQuantizationParams() { + return params; + } + + /** + * Serializes the quantization state to a byte array. + * + * @return a byte array representing the serialized state. + * @throws IOException if an I/O error occurs during serialization. 
+ */ + @Override + public byte[] toByteArray() throws IOException { + return QuantizationStateSerializer.serialize(this, null); + } + + /** + * Deserializes a DefaultQuantizationState from a byte array. + * + * @param bytes the byte array containing the serialized state. + * @return the deserialized DefaultQuantizationState. + * @throws IOException if an I/O error occurs during deserialization. + * @throws ClassNotFoundException if the class of the serialized object cannot be found. + */ + public static DefaultQuantizationState fromByteArray(byte[] bytes) throws IOException, ClassNotFoundException { + return (DefaultQuantizationState) + QuantizationStateSerializer.deserialize(bytes, (parentParams, specificData) -> + new DefaultQuantizationState( + (SQParams) parentParams) + ); + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java new file mode 100644 index 0000000000..6edfc35b1b --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.util.QuantizationStateSerializer; + +import java.io.IOException; + +/** + * MultiBitScalarQuantizationState represents the state of multi-bit scalar quantization, + * including the thresholds used for quantization. 
+ */ +public final class MultiBitScalarQuantizationState implements QuantizationState { + private final SQParams quantizationParams; + private final float[][] thresholds; + + /** + * Constructs a MultiBitScalarQuantizationState with the given quantization parameters and thresholds. + * + * @param quantizationParams the scalar quantization parameters. + * @param thresholds the threshold values for multi-bit quantization, organized as a 2D array + * where each row corresponds to a different bit level. + */ + public MultiBitScalarQuantizationState(final SQParams quantizationParams, final float[][] thresholds) { + this.quantizationParams = quantizationParams; + this.thresholds = thresholds; + } + + @Override + public SQParams getQuantizationParams() { + return quantizationParams; + } + + /** + * Returns the thresholds used in the quantization process. + * + * @return a 2D array of threshold values. + */ + public float[][] getThresholds() { + return thresholds; + } + + @Override + public byte[] toByteArray() throws IOException { + return QuantizationStateSerializer.serialize(this, thresholds); + } + + public static MultiBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException { + return (MultiBitScalarQuantizationState) + QuantizationStateSerializer.deserialize(bytes, (parentParams, thresholds) -> + new MultiBitScalarQuantizationState( + (SQParams) parentParams, + (float[][]) thresholds) + ); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java new file mode 100644 index 0000000000..6227a1168e --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.knn.quantization.models.quantizationState; + +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.util.QuantizationStateSerializer; + +import java.io.IOException; + +/** + * OneBitScalarQuantizationState represents the state of one-bit scalar quantization, + * including the mean values used for quantization. + */ +public final class OneBitScalarQuantizationState implements QuantizationState { + private final SQParams quantizationParams; + private final float[] mean; + + /** + * Constructs a OneBitScalarQuantizationState with the given quantization parameters and mean values. + * + * @param quantizationParams the scalar quantization parameters. + * @param mean the mean values for each dimension. + */ + public OneBitScalarQuantizationState(final SQParams quantizationParams, final float[] mean) { + this.quantizationParams = quantizationParams; + this.mean = mean; + } + + @Override + public SQParams getQuantizationParams() { + return quantizationParams; + } + + /** + * Returns the mean values used in the quantization process. + * + * @return an array of mean values. 
+ */ + public float[] getMean() { + return mean; + } + + @Override + public byte[] toByteArray() throws IOException { + return QuantizationStateSerializer.serialize(this, mean); + } + + public static OneBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException { + return (OneBitScalarQuantizationState) QuantizationStateSerializer.deserialize( + bytes, + (parentParams, mean) -> + new OneBitScalarQuantizationState( + (SQParams) parentParams, + (float[]) mean) + ); + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java new file mode 100644 index 0000000000..6d19e385c3 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; + +import java.io.IOException; +import java.io.Serializable; + +/** + * QuantizationState interface represents the state of a quantization process, including the parameters used. + * This interface provides methods for serializing and deserializing the state. + */ +public interface QuantizationState extends Serializable { + /** + * Returns the quantization parameters associated with this state. + * + * @return the quantization parameters. + */ + QuantizationParams getQuantizationParams(); + + /** + * Serializes the quantization state to a byte array. + * + * @return a byte array representing the serialized state. + * @throws IOException if an I/O error occurs during serialization. 
+ */ + byte[] toByteArray() throws IOException; +} + diff --git a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java new file mode 100644 index 0000000000..b8d33b5957 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.requests; + +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; + +/** + * TrainingRequest represents a request for training a quantizer. + * + * @param the type of vectors to be trained. + */ +public abstract class TrainingRequest { + private final QuantizationParams params; + private final int totalNumberOfVectors; + private int[] sampledIndices; + + /** + * Constructs a TrainingRequest with the given parameters and total number of vectors. + * + * @param params the quantization parameters. + * @param totalNumberOfVectors the total number of vectors. + */ + protected TrainingRequest(QuantizationParams params, int totalNumberOfVectors) { + this.params = params; + this.totalNumberOfVectors = totalNumberOfVectors; + } + + /** + * Returns the quantization parameters. + * + * @return the quantization parameters. + */ + public QuantizationParams getParams() { + return params; + } + + /** + * Returns the total number of vectors. + * + * @return the total number of vectors. + */ + public int getTotalNumberOfVectors() { + return totalNumberOfVectors; + } + + /** + * Sets the sampled indices for this training request. + * + * @param sampledIndices the sampled indices. + */ + public void setSampledIndices(int[] sampledIndices) { + this.sampledIndices = sampledIndices; + } + + /** + * Returns the sampled indices for this training request. + * + * @return the sampled indices. 
+ */ + public int[] getSampledIndices() { + return sampledIndices; + } + + /** + * Returns the vector corresponding to the specified document ID. + * + * @param docId the document ID. + * @return the vector corresponding to the specified document ID. + */ + public abstract T getVectorByDocId(int docId); +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java new file mode 100644 index 0000000000..13dded2049 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java @@ -0,0 +1,175 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplingFactory; +import org.opensearch.knn.quantization.util.BitPackingUtils; +import org.opensearch.knn.quantization.util.QuantizerHelper; + +import java.util.ArrayList; +import java.util.List; + +/** + * MultiBitScalarQuantizer is responsible for quantizing vectors into multi-bit representations per dimension. + * It supports multiple bits per coordinate, allowing for finer granularity in quantization. 
+ */ +public class MultiBitScalarQuantizer implements Quantizer { + private final int bitsPerCoordinate; // Number of bits used to quantize each dimension + private final int samplingSize; // Sampling size for training + private final Sampler sampler; // Sampler for training + private static final boolean IS_TRAINING_REQUIRED = true; + + /** + * Constructs a MultiBitScalarQuantizer with a specified number of bits per coordinate. + * + * @param bitsPerCoordinate the number of bits used per coordinate for quantization. + */ + public MultiBitScalarQuantizer(final int bitsPerCoordinate) { + this( + bitsPerCoordinate, + 25000, + SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR) + ); + } + + /** + * Constructs a MultiBitScalarQuantizer with a specified number of bits per coordinate and sampling size. + * + * @param bitsPerCoordinate the number of bits used per coordinate for quantization. + * @param samplingSize the number of samples to use for training. + * @param sampler the sampler to use for training. + */ + public MultiBitScalarQuantizer( + final int bitsPerCoordinate, + final int samplingSize, + final Sampler sampler + ) { + if (bitsPerCoordinate < 2) { + throw new IllegalArgumentException("bitsPerCoordinate must be greater than or equal to 2 for multibit quantizer."); + } + this.bitsPerCoordinate = bitsPerCoordinate; + this.samplingSize = samplingSize; + this.sampler = sampler; + } + + /** + * Trains the quantizer based on the provided training request, which should be of type SamplingTrainingRequest. + * The training process calculates the mean and standard deviation for each dimension and then determines + * threshold values for quantization based on these statistics. + * + * @param trainingRequest the request containing the data and parameters for training. + * @return a MultiBitScalarQuantizationState containing the computed thresholds. 
+ */ + @Override + public QuantizationState train(final TrainingRequest trainingRequest) { + if (!IS_TRAINING_REQUIRED) { + return new DefaultQuantizationState(trainingRequest.getParams()); + } + SQParams params = QuantizerHelper.validateAndExtractParams(trainingRequest); + int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); + + int dimension = trainingRequest.getVectorByDocId(sampledIndices[0]).length; + float[] sumAndMean = new float[dimension]; + float[] sumSqAndStdDev = new float[dimension]; + // Calculate sum, mean, and standard deviation in one pass + QuantizerHelper.calculateSumMeanAndStdDev(trainingRequest, sampledIndices, sumAndMean, sumSqAndStdDev); + float[][] thresholds = calculateThresholds(sumAndMean, sumSqAndStdDev, dimension); + return new MultiBitScalarQuantizationState(params, thresholds); + } + + /** + * Quantizes the provided vector using the provided quantization state, producing a quantized output. + * The vector is quantized based on the thresholds in the quantization state. + * + * @param vector the vector to quantize. + * @param state the quantization state containing threshold information. + * @return a BinaryQuantizationOutput containing the quantized data. 
+ */ + @Override + public QuantizationOutput quantize(final float[] vector, final QuantizationState state) { + if (vector == null) { + throw new IllegalArgumentException("Vector to quantize must not be null."); + } + validateState(state); + MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) state; + float[][] thresholds = multiBitState.getThresholds(); + if (thresholds == null || thresholds[0].length != vector.length) { + throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); + } + + List bitArrays = new ArrayList<>(); + for (int i = 0; i < bitsPerCoordinate; i++) { + byte[] bitArray = new byte[vector.length]; + for (int j = 0; j < vector.length; j++) { + bitArray[j] = (byte) (vector[j] > thresholds[i][j] ? 1 : 0); + } + bitArrays.add(bitArray); + } + + return new BinaryQuantizationOutput(BitPackingUtils.packBits(bitArrays)); + } + /** + * Calculates the thresholds for quantization based on mean and standard deviation. + * + * @param mean the mean for each dimension. + * @param stdDev the standard deviation for each dimension. + * @param dimension the number of dimensions in the vectors. + * @return the thresholds for quantization. + */ + private float[][] calculateThresholds( + final float[] mean, + final float[] stdDev, + final int dimension + ) { + float[][] thresholds = new float[bitsPerCoordinate][dimension]; + float coef = bitsPerCoordinate + 1; + for (int i = 0; i < bitsPerCoordinate; i++) { + float iCoef = -1 + 2 * (i + 1) / coef; + for (int j = 0; j < dimension; j++) { + thresholds[i][j] = mean[j] + iCoef * stdDev[j]; + } + } + return thresholds; + } + + /** + * Quantizes a given float vector into a byte array representation. + * + *

Note: This method currently throws an {@link UnsupportedOperationException} as the + * quantization state is required for the OneBitScalar Quantizer. + * + * @param vector the float vector to be quantized + * @return a {@link QuantizationOutput} containing the byte array representation of the quantized vector + * @throws UnsupportedOperationException if the quantization state is not available + */ + private QuantizationOutput quantize(final float[] vector) { + throw new UnsupportedOperationException("Quantization state is required for OneBitScalar Quantizer."); + } + + /** + * Validates the quantization state to ensure it is of the expected type. + * + * @param state the quantization state to validate. + * @throws IllegalArgumentException if the state is not an instance of MultiBitScalarQuantizationState. + */ + private void validateState(final QuantizationState state) { + if (state instanceof DefaultQuantizationState) { + throw new UnsupportedOperationException("Quantization state is required for MultiBitScalar Quantizer."); + } + if (!(state instanceof MultiBitScalarQuantizationState)) { + throw new IllegalArgumentException("Quantization state must be of type MultiBitScalarQuantizationState."); + } + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java new file mode 100644 index 0000000000..2cbb7a2341 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import 
org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplingFactory; +import org.opensearch.knn.quantization.util.BitPackingUtils; +import org.opensearch.knn.quantization.util.QuantizerHelper; + +import java.util.Collections; + +/** + * OneBitScalarQuantizer is responsible for quantizing vectors using a single bit per dimension. + * It computes the mean of each dimension during training and then uses these means as thresholds + * for quantizing the vectors. + */ +public class OneBitScalarQuantizer implements Quantizer { + private final int samplingSize; // Sampling size for training + private static final boolean IS_TRAINING_REQUIRED = true; + private final Sampler sampler; // Sampler for training + + /** + * Constructs a OneBitScalarQuantizer with a default sampling size of 25000. + */ + public OneBitScalarQuantizer() { + this( + 25000, + SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR) + ); + } + /** + * Constructs a OneBitScalarQuantizer with a specified sampling size. + * + * @param samplingSize the number of samples to use for training. + */ + public OneBitScalarQuantizer( + final int samplingSize, + final Sampler sampler + ) { + + this.samplingSize = samplingSize; + this.sampler = sampler;; + } + + /** + * Trains the quantizer by calculating the mean of each dimension from the sampled vectors. + * These means are used as thresholds in the quantization process. + * + * @param trainingRequest the request containing the data and parameters for training. + * @return a OneBitScalarQuantizationState containing the calculated means. 
+ */ + @Override + public QuantizationState train(final TrainingRequest trainingRequest) { + if (!IS_TRAINING_REQUIRED) { + return new DefaultQuantizationState(trainingRequest.getParams()); + } + SQParams params = QuantizerHelper.validateAndExtractParams(trainingRequest); + int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); + float[] mean = QuantizerHelper.calculateMean(trainingRequest, sampledIndices); + return new OneBitScalarQuantizationState(params, mean); + } + + /** + * Quantizes the provided vector using the given quantization state. + * It compares each dimension of the vector against the corresponding mean (threshold) to determine the quantized value. + * + * @param vector the vector to quantize. + * @param state the quantization state containing the means for each dimension. + * @return a BinaryQuantizationOutput containing the quantized data. + */ + @Override + public QuantizationOutput quantize(final float[] vector, final QuantizationState state) { + if (vector == null) { + throw new IllegalArgumentException("Vector to quantize must not be null."); + } + validateState(state); + OneBitScalarQuantizationState binaryState = (OneBitScalarQuantizationState) state; + float[] thresholds = binaryState.getMean(); + if (thresholds == null || thresholds.length != vector.length) { + throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); + } + byte[] quantizedVector = new byte[vector.length]; + for (int i = 0; i < vector.length; i++) { + quantizedVector[i] = (byte) (vector[i] > thresholds[i] ? 1 : 0); + } + return new BinaryQuantizationOutput(BitPackingUtils.packBits(Collections.singletonList(quantizedVector))); + } + + /** + * Quantizes a given float vector into a byte array representation. + * + *

Note: This method currently throws an {@link UnsupportedOperationException} as the + * quantization state is required for the OneBitScalar Quantizer. + * + * @param vector the float vector to be quantized + * @return a {@link QuantizationOutput} containing the byte array representation of the quantized vector + * @throws UnsupportedOperationException if the quantization state is not available + */ + private QuantizationOutput quantize(final float[] vector) { + throw new UnsupportedOperationException("Quantization state is required for OneBitScalar Quantizer."); + } + + /** + * Validates the quantization state to ensure it is of the expected type. + * + * @param state the quantization state to validate. + * @throws IllegalArgumentException if the state is not an instance of OneBitScalarQuantizationState. + */ + private void validateState(final QuantizationState state) { + if (state instanceof DefaultQuantizationState) { + throw new UnsupportedOperationException("Quantization state is required for OneBitScalar Quantizer."); + } + if (!(state instanceof OneBitScalarQuantizationState)) { + throw new IllegalArgumentException("Quantization state must be of type OneBitScalarQuantizationState."); + } + } +} + diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java new file mode 100644 index 0000000000..ef26612edb --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; + +/** + * The Quantizer interface defines the methods required for training 
and quantizing vectors + * in the context of K-Nearest Neighbors (KNN) and similar machine learning tasks. + * It supports training to determine quantization parameters and quantizing data vectors + * based on these parameters. + * + * @param The type of the vector or data to be quantized. + * @param The type of the quantized output, typically a compressed or encoded representation. + */ +public interface Quantizer { + + /** + * Trains the quantizer based on the provided training request. The training process typically + * involves learning parameters that can be used to quantize vectors. + * + * @param trainingRequest the request containing data and parameters for training. + * @return a QuantizationState containing the learned parameters. + */ + QuantizationState train(TrainingRequest trainingRequest); + + /** + * Quantizes the provided vector using the specified quantization state. + * + * @param vector the vector to quantize. + * @param state the quantization state containing parameters for quantization. + * @return a QuantizationOutput containing the quantized representation of the vector. + */ + QuantizationOutput quantize(T vector, QuantizationState state); +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java new file mode 100644 index 0000000000..97559a3b4d --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.sampler; + +import java.util.Arrays; +import java.util.Random; +import java.util.stream.IntStream; + +/** + * ReservoirSampler implements the Sampler interface and provides a method for sampling + * a specified number of indices from a total number of vectors using the reservoir sampling algorithm. 
+ * This algorithm is particularly useful for randomly sampling a subset of data from a larger set + * when the total size of the dataset is unknown or very large. + */ +final class ReservoirSampler implements Sampler { + + private final Random random; + + /** + * Constructs a ReservoirSampler with a new Random instance. + */ + public ReservoirSampler() { + this(new Random()); + } + + /** + * Constructs a ReservoirSampler with a specified random seed for reproducibility. + * + * @param seed the seed for the random number generator. + */ + public ReservoirSampler(final long seed) { + this(new Random(seed)); + } + + /** + * Constructs a ReservoirSampler with a specified Random instance. + * + * @param random the Random instance for generating random numbers. + */ + public ReservoirSampler(final Random random) { + this.random = random; + } + + /** + * Samples indices from the range [0, totalNumberOfVectors). + * If the total number of vectors is less than or equal to the sample size, it returns all indices. + * Otherwise, it uses the reservoir sampling algorithm to select a random subset. + * + * @param totalNumberOfVectors the total number of vectors to sample from. + * @param sampleSize the number of indices to sample. + * @return an array of sampled indices. + */ + @Override + public int[] sample(final int totalNumberOfVectors, final int sampleSize) { + if (totalNumberOfVectors <= sampleSize) { + return IntStream.range(0, totalNumberOfVectors).toArray(); + } + return reservoirSampleIndices(totalNumberOfVectors, sampleSize); + } + + /** + * Applies the reservoir sampling algorithm to select a random sample of indices. + * This method ensures that each index in the range [0, numVectors) has an equal probability + * of being included in the sample. + * + * @param numVectors the total number of vectors. + * @param sampleSize the number of indices to sample. + * @return an array of sampled indices. 
+ */ + private int[] reservoirSampleIndices(final int numVectors, final int sampleSize) { + int[] indices = IntStream.range(0, sampleSize).toArray(); + for (int i = sampleSize; i < numVectors; i++) { + int j = random.nextInt(i + 1); + if (j < sampleSize) { + indices[j] = i; + } + } + // If sorted indices are required, uncomment the following line + // Arrays.sort(indices); + return indices; + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java new file mode 100644 index 0000000000..9021073b4e --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java @@ -0,0 +1,10 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.sampler; + +public interface Sampler { + int[] sample(int totalNumberOfVectors, int sampleSize); +} diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java new file mode 100644 index 0000000000..be228fe6f9 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.sampler; + +/** + * SamplingFactory is a factory class for creating instances of Sampler. + * It uses the factory design pattern to encapsulate the creation logic for different types of samplers. + */ +public final class SamplingFactory { + + /** + * Private constructor to prevent instantiation of this class. + * The class is not meant to be instantiated, as it provides static methods only. + */ + private SamplingFactory() { + + } + + /** + * SamplerType is an enumeration of the different types of samplers that can be created by the factory. 
+ */ + public enum SamplerType { + RESERVOIR, // Represents a reservoir sampling strategy + // Add more enum values here for additional sampler types + } + + /** + * Creates and returns a Sampler instance based on the specified SamplerType. + * + * @param samplerType the type of sampler to create. + * @return a Sampler instance. + * @throws IllegalArgumentException if the sampler type is not supported. + */ + public static Sampler getSampler(final SamplerType samplerType) { + switch (samplerType) { + case RESERVOIR: + return new ReservoirSampler(); + // Add more cases for different samplers here + default: + throw new IllegalArgumentException("Unsupported sampler type: " + samplerType); + } + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/util/BitPackingUtils.java b/src/main/java/org/opensearch/knn/quantization/util/BitPackingUtils.java new file mode 100644 index 0000000000..bf8588947d --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/util/BitPackingUtils.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.knn.quantization.util; + +import lombok.experimental.UtilityClass; + +import java.util.List; + +/** + * Utility class for bit packing operations. + * Provides methods for packing arrays of bits into byte arrays for efficient storage or transmission. + */ +@UtilityClass +public class BitPackingUtils { + + /** + * Packs the list of bit arrays into a single byte array. + * Each byte in the resulting array contains up to 8 bits from the bit arrays, packed from left to right. + * + * @param bitArrays the list of bit arrays to be packed. Each bit array should contain only 0s and 1s. + * @return a byte array containing the packed bits. + * @throws IllegalArgumentException if the bitArrays list is empty, if any bit array is null, or if bit arrays have inconsistent lengths. 
+ */ + public static byte[] packBits(List bitArrays) { + if (bitArrays.isEmpty()) { + throw new IllegalArgumentException("The list of bit arrays cannot be empty."); + } + + int bitArrayLength = bitArrays.get(0).length; + int bitLength = bitArrays.size() * bitArrayLength; + int byteLength = (bitLength + 7) / 8; + byte[] packedArray = new byte[byteLength]; + + int bitPosition = 0; + for (byte[] bitArray : bitArrays) { + if (bitArray == null) { + throw new IllegalArgumentException("Bit array cannot be null."); + } + if (bitArray.length != bitArrayLength) { + throw new IllegalArgumentException("All bit arrays must have the same length."); + } + + for (byte bit : bitArray) { + int byteIndex = bitPosition / 8; + int bitIndex = 7 - (bitPosition % 8); + if (bit == 1) { + packedArray[byteIndex] |= (1 << bitIndex); + } + bitPosition++; + } + } + + return packedArray; + } +} + diff --git a/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java b/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java new file mode 100644 index 0000000000..d1365fdb5b --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.util; + +import lombok.experimental.UtilityClass; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +import java.io.ByteArrayOutputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.io.IOException; +import java.io.ByteArrayInputStream; +import java.io.ObjectInputStream; + +/** + * QuantizationStateSerializer is a utility class that provides methods for serializing and deserializing + * QuantizationState objects along with their specific data. 
+ */ +@UtilityClass +public class QuantizationStateSerializer { + + + /** + * A functional interface for deserializing specific data associated with a QuantizationState. + */ + @FunctionalInterface + public interface SerializableDeserializer { + QuantizationState deserialize(QuantizationParams parentParams, Serializable specificData); + } + + /** + * Serializes the QuantizationState and specific data into a byte array. + * + * @param state The QuantizationState to serialize. + * @param specificData The specific data related to the state, to be serialized. + * @return A byte array representing the serialized state and specific data. + * @throws IOException If an I/O error occurs during serialization. + */ + public static byte[] serialize(QuantizationState state, Serializable specificData) throws IOException { + byte[] parentBytes = serializeParentParams(state.getQuantizationParams()); + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream out = new ObjectOutputStream(bos)) { + out.writeInt(parentBytes.length); // Write the length of the parent bytes + out.write(parentBytes); // Write the parent bytes + out.writeObject(specificData); // Write the specific data + out.flush(); + return bos.toByteArray(); + } + } + + /** + * Deserializes a QuantizationState and its specific data from a byte array. + * + * @param bytes The byte array containing the serialized data. + * @param specificDataDeserializer The deserializer for the specific data associated with the state. + * @return The deserialized QuantizationState including its specific data. + * @throws IOException If an I/O error occurs during deserialization. + * @throws ClassNotFoundException If the class of the serialized object cannot be found. 
+ */ + public static QuantizationState deserialize(byte[] bytes, SerializableDeserializer specificDataDeserializer) + throws IOException, ClassNotFoundException { + try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); + ObjectInputStream in = new ObjectInputStream(bis)) { + int parentLength = in.readInt(); // Read the length of the + // Read the length of the parent bytes + byte[] parentBytes = new byte[parentLength]; + in.readFully(parentBytes); // Read the parent bytes + QuantizationParams parentParams = deserializeParentParams(parentBytes); // Deserialize the parent params + Serializable specificData = (Serializable) in.readObject(); // Read the specific data + return specificDataDeserializer.deserialize(parentParams, specificData); + } + } + + /** + * Serializes the parent parameters of the QuantizationState into a byte array. + * + * @param params The QuantizationParams to serialize. + * @return A byte array representing the serialized parent parameters. + * @throws IOException If an I/O error occurs during serialization. + */ + private static byte[] serializeParentParams(QuantizationParams params) throws IOException { + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); + ObjectOutputStream out = new ObjectOutputStream(bos)) { + out.writeObject(params); + out.flush(); + return bos.toByteArray(); + } + } + + /** + * Deserializes the parent parameters of the QuantizationState from a byte array. + * + * @param bytes The byte array containing the serialized parent parameters. + * @return The deserialized QuantizationParams. + * @throws IOException If an I/O error occurs during deserialization. + * @throws ClassNotFoundException If the class of the serialized object cannot be found. 
+ */ + private static QuantizationParams deserializeParentParams(byte[] bytes) + throws IOException, ClassNotFoundException { + try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); + ObjectInputStream in = new ObjectInputStream(bis)) { + return (QuantizationParams) in.readObject(); + } + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java b/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java new file mode 100644 index 0000000000..4a48d638b4 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.knn.quantization.util; + +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import lombok.experimental.UtilityClass; + +/** + * Utility class providing common methods for quantizer operations, such as parameter validation and + * extraction. This class is designed to be used with various quantizer implementations that require + * consistent handling of training requests and sampled indices. + */ +@UtilityClass +public class QuantizerHelper { + + /** + * Validates the provided training request to ensure it contains non-null quantization parameters. + * Extracts and returns the SQParams from the training request. + * + * @param trainingRequest the training request to validate and extract parameters from. + * @return the extracted SQParams. + * @throws IllegalArgumentException if the SQParams are null. 
+ */ + public static SQParams validateAndExtractParams(TrainingRequest trainingRequest) { + QuantizationParams params = trainingRequest.getParams(); + if (params == null || !(params instanceof SQParams)) { + throw new IllegalArgumentException("Quantization parameters must not be null and must be of type SQParams."); + } + return (SQParams) params; + } + + /** + * Calculates the mean vector from a set of sampled vectors. + * + *

This method takes a {@link TrainingRequest} object and an array of sampled indices, + * retrieves the vectors corresponding to these indices, and calculates the mean vector. + * Each element of the mean vector is computed as the average of the corresponding elements + * of the sampled vectors.

+ * + * @param samplingRequest The {@link TrainingRequest} containing the dataset and methods to access vectors by their indices. + * @param sampledIndices An array of indices representing the sampled vectors to be used for mean calculation. + * @return A float array representing the mean vector of the sampled vectors. + * @throws IllegalArgumentException If any of the vectors at the sampled indices are null. + * @throws IllegalStateException If the mean array is unexpectedly null after processing the vectors. + */ + public static float[] calculateMean(TrainingRequest samplingRequest, int[] sampledIndices) { + int totalSamples = sampledIndices.length; + float[] mean = null; + for (int index : sampledIndices) { + float[] vector = samplingRequest.getVectorByDocId(index); + if (vector == null) { + throw new IllegalArgumentException("Vector at sampled index " + index + " is null."); + } + if (mean == null) { + mean = new float[vector.length]; + } + for (int j = 0; j < vector.length; j++) { + mean[j] += vector[j]; + } + } + if (mean == null) { + throw new IllegalStateException("Mean array should not be null after processing vectors."); + } + for (int j = 0; j < mean.length; j++) { + mean[j] /= totalSamples; + } + return mean; + } + + /** + * Calculates the sum, sum of squares, mean, and standard deviation for each dimension in a single pass. + * + * @param trainingRequest the request containing the data and parameters for training. + * @param sampledIndices the indices of the sampled vectors. + * @param sumAndMean the array to store the sum and then the mean of each dimension. + * @param sumSqAndStdDev the array to store the sum of squares and then the standard deviation of each dimension. 
+ */ + public static void calculateSumMeanAndStdDev( + TrainingRequest trainingRequest, + int[] sampledIndices, + float[] sumAndMean, + float[] sumSqAndStdDev + ) { + int totalSamples = sampledIndices.length; + int dimension = sumAndMean.length; + + // Single pass to calculate sum and sum of squares + for (int index : sampledIndices) { + float[] vector = trainingRequest.getVectorByDocId(index); + if (vector == null) { + throw new IllegalArgumentException("Vector at sampled index " + index + " is null."); + } + for (int j = 0; j < dimension; j++) { + sumAndMean[j] += vector[j]; + sumSqAndStdDev[j] += vector[j] * vector[j]; + } + } + + // Calculate mean and standard deviation in one pass + for (int j = 0; j < dimension; j++) { + sumAndMean[j] = sumAndMean[j] / totalSamples; + sumSqAndStdDev[j] = (float) Math.sqrt((sumSqAndStdDev[j] / totalSamples) - (sumAndMean[j] * sumAndMean[j])); + } + } +} + diff --git a/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java new file mode 100644 index 0000000000..f14babd379 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.enums; + +import org.opensearch.knn.KNNTestCase; + +public class QuantizationTypeTests extends KNNTestCase { + + public void testQuantizationTypeValues() { + QuantizationType[] expectedValues = { + QuantizationType.SPACE_QUANTIZATION, + QuantizationType.VALUE_QUANTIZATION + }; + assertArrayEquals(expectedValues, QuantizationType.values()); + } + + public void testQuantizationTypeValueOf() { + assertEquals(QuantizationType.SPACE_QUANTIZATION, QuantizationType.valueOf("SPACE_QUANTIZATION")); + assertEquals(QuantizationType.VALUE_QUANTIZATION, QuantizationType.valueOf("VALUE_QUANTIZATION")); + } +} diff --git 
a/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java b/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java new file mode 100644 index 0000000000..0ba7839cd1 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.enums; + +import org.opensearch.knn.KNNTestCase; + +public class SQTypesTests extends KNNTestCase { + public void testSQTypesValues() { + SQTypes[] expectedValues = { + SQTypes.FP16, + SQTypes.FP8, + SQTypes.INT8, + SQTypes.INT6, + SQTypes.INT4, + SQTypes.ONE_BIT, + SQTypes.TWO_BIT, + SQTypes.FOUR_BIT, + SQTypes.UNSUPPORTED_TYPE + }; + assertArrayEquals(expectedValues, SQTypes.values()); + } + + public void testSQTypesValueOf() { + assertEquals(SQTypes.FP16, SQTypes.valueOf("FP16")); + assertEquals(SQTypes.INT8, SQTypes.valueOf("INT8")); + assertEquals(SQTypes.INT6, SQTypes.valueOf("INT6")); + assertEquals(SQTypes.INT4, SQTypes.valueOf("INT4")); + assertEquals(SQTypes.ONE_BIT, SQTypes.valueOf("ONE_BIT")); + assertEquals(SQTypes.TWO_BIT, SQTypes.valueOf("TWO_BIT")); + assertEquals(SQTypes.FOUR_BIT, SQTypes.valueOf("FOUR_BIT")); + assertEquals(SQTypes.UNSUPPORTED_TYPE, SQTypes.valueOf("UNSUPPORTED_TYPE")); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java new file mode 100644 index 0000000000..47d7123f63 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.enums; + +import org.opensearch.knn.KNNTestCase; + +public class ValueQuantizationTypeTests extends KNNTestCase { + public void 
testValueQuantizationTypeValues() { + ValueQuantizationType[] expectedValues = { + ValueQuantizationType.SQ + }; + assertArrayEquals(expectedValues, ValueQuantizationType.values()); + } + + public void testValueQuantizationTypeValueOf() { + assertEquals(ValueQuantizationType.SQ, ValueQuantizationType.valueOf("SQ")); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java new file mode 100644 index 0000000000..48dbd0bea8 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.factory; + +import org.junit.Before; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.Quantizer; + +import java.lang.reflect.Field; +import java.util.concurrent.atomic.AtomicBoolean; + +public class QuantizerFactoryTests extends KNNTestCase { + + @Before + public void resetIsRegisteredFlag() throws NoSuchFieldException, IllegalAccessException { + Field isRegisteredField = QuantizerFactory.class.getDeclaredField("isRegistered"); + isRegisteredField.setAccessible(true); + AtomicBoolean isRegistered = (AtomicBoolean) isRegisteredField.get(null); + isRegistered.set(false); + } + + public void test_Lazy_Registration() { + SQParams params = new SQParams(SQTypes.ONE_BIT); + assertFalse(isRegisteredFieldAccessible()); + Quantizer quantizer = QuantizerFactory.getQuantizer(params); + assertTrue(quantizer instanceof OneBitScalarQuantizer); + 
assertTrue(isRegisteredFieldAccessible()); + } + + public void testGetQuantizer_withOneBitSQParams () { + SQParams params = new SQParams(SQTypes.ONE_BIT); + Quantizer quantizer = QuantizerFactory.getQuantizer(params); + assertTrue(quantizer instanceof OneBitScalarQuantizer); + } + + public void testGetQuantizer_withTwoBitSQParams () { + SQParams params = new SQParams(SQTypes.TWO_BIT); + Quantizer quantizer = QuantizerFactory.getQuantizer(params); + assertTrue(quantizer instanceof MultiBitScalarQuantizer); + } + + public void testGetQuantizer_withFourBitSQParams () { + SQParams params = new SQParams(SQTypes.FOUR_BIT); + Quantizer quantizer = QuantizerFactory.getQuantizer(params); + assertTrue(quantizer instanceof MultiBitScalarQuantizer); + } + + public void testGetQuantizer_withUnsupportedType () { + SQParams params = new SQParams(SQTypes.UNSUPPORTED_TYPE); + try { + QuantizerFactory.getQuantizer(params); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("No quantizer registered for type identifier")); + } + } + public void testGetQuantizer_withNullParams () { + try { + QuantizerFactory.getQuantizer(null); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException e) { + assertEquals("Quantization parameters must not be null.", e.getMessage()); + } + } + + public void testConcurrentRegistration () throws InterruptedException { + Runnable task = () -> { + SQParams params = new SQParams(SQTypes.ONE_BIT); + QuantizerFactory.getQuantizer(params); + }; + + Thread thread1 = new Thread(task); + Thread thread2 = new Thread(task); + thread1.start(); + thread2.start(); + thread1.join(); + thread2.join(); + assertTrue(isRegisteredFieldAccessible()); + } + + private boolean isRegisteredFieldAccessible () { + try { + Field isRegisteredField = QuantizerFactory.class.getDeclaredField("isRegistered"); + isRegisteredField.setAccessible(true); + AtomicBoolean isRegistered = (AtomicBoolean) 
isRegisteredField.get(null); + return isRegistered.get(); + } catch (NoSuchFieldException | IllegalAccessException e) { + fail("Failed to access isRegistered field."); + return false; + } + } +} \ No newline at end of file diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java new file mode 100644 index 0000000000..a848bc474c --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.factory; + +import org.junit.BeforeClass; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.QuantizationType; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.Quantizer; + +public class QuantizerRegistryTests extends KNNTestCase { + + @BeforeClass + public static void setup() { + // Register the quantizers for testing with enums + QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE_QUANTIZATION, SQTypes.ONE_BIT, OneBitScalarQuantizer::new); + QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE_QUANTIZATION, SQTypes.TWO_BIT, () -> new MultiBitScalarQuantizer(2)); + QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE_QUANTIZATION, SQTypes.FOUR_BIT, () -> new MultiBitScalarQuantizer(4)); + } + + public void testRegisterAndGetQuantizer() { + // Test for OneBitScalarQuantizer + SQParams oneBitParams = new SQParams(SQTypes.ONE_BIT); + Quantizer oneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams); + 
assertTrue(oneBitQuantizer instanceof OneBitScalarQuantizer); + + // Test for MultiBitScalarQuantizer (2-bit) + SQParams twoBitParams = new SQParams(SQTypes.TWO_BIT); + Quantizer twoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams); + assertTrue(twoBitQuantizer instanceof MultiBitScalarQuantizer); + + // Test for MultiBitScalarQuantizer (4-bit) + SQParams fourBitParams = new SQParams(SQTypes.FOUR_BIT); + Quantizer fourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); + assertTrue(fourBitQuantizer instanceof MultiBitScalarQuantizer); + } + + public void testGetQuantizer_withUnsupportedTypeIdentifier() { + // Create SQParams with an unsupported type identifier + SQParams params = new SQParams(SQTypes.UNSUPPORTED_TYPE); // Assuming UNSUPPORTED_TYPE is not registered + + // Expect IllegalArgumentException when requesting a quantizer with unsupported params + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + QuantizerRegistry.getQuantizer(params); + }); + + assertTrue(exception.getMessage().contains("No quantizer registered for type identifier")); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java new file mode 100644 index 0000000000..50a8eee60d --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizationState; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import 
org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; + +import java.io.IOException; + +public class QuantizationStateSerializerTests extends KNNTestCase { + + public void testSerializeAndDeserializeOneBitScalarQuantizationState() throws IOException, ClassNotFoundException { + SQParams params = new SQParams(SQTypes.ONE_BIT); + float[] mean = new float[]{0.1f, 0.2f, 0.3f}; + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + + byte[] serialized = state.toByteArray(); + OneBitScalarQuantizationState deserialized = OneBitScalarQuantizationState.fromByteArray(serialized); + + assertArrayEquals(mean, deserialized.getMean(), 0.0f); + assertEquals(params, deserialized.getQuantizationParams()); + } + + public void testSerializeAndDeserializeMultiBitScalarQuantizationState() throws IOException, ClassNotFoundException { + SQParams params = new SQParams(SQTypes.TWO_BIT); + float[][] thresholds = new float[][]{ + {0.1f, 0.2f, 0.3f}, + {0.4f, 0.5f, 0.6f} + }; + MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + + byte[] serialized = state.toByteArray(); + MultiBitScalarQuantizationState deserialized = MultiBitScalarQuantizationState.fromByteArray(serialized); + + assertArrayEquals(thresholds, deserialized.getThresholds()); + assertEquals(params, deserialized.getQuantizationParams()); + } +} \ No newline at end of file diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java new file mode 100644 index 0000000000..61d80b7995 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.knn.quantization.quantizationState; + +import 
org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; + +import java.io.IOException; + +public class QuantizationStateTests extends KNNTestCase { + + public void testOneBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException { + SQParams params = new SQParams(SQTypes.ONE_BIT); + float[] mean = {1.0f, 2.0f, 3.0f}; + + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + + byte[] serializedState = state.toByteArray(); + OneBitScalarQuantizationState deserializedState = OneBitScalarQuantizationState.fromByteArray(serializedState); + float delta = 0.0001f; + + assertArrayEquals(mean, deserializedState.getMean(), delta); + assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType()); + } + + + public void testMultiBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException { + SQParams params = new SQParams(SQTypes.TWO_BIT); + float[][] thresholds = { + {0.5f, 1.5f, 2.5f}, + {1.0f, 2.0f, 3.0f} + }; + + MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + + byte[] serializedState = state.toByteArray(); + MultiBitScalarQuantizationState deserializedState = MultiBitScalarQuantizationState.fromByteArray(serializedState); + float delta = 0.0001f; + + for (int i = 0; i < thresholds.length; i++) { + assertArrayEquals(thresholds[i], deserializedState.getThresholds()[i],delta); + } + assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType()); + } + + public void 
testDefaultQuantizationStateSerialization() throws IOException, ClassNotFoundException { + SQParams params = new SQParams(SQTypes.UNSUPPORTED_TYPE); + + DefaultQuantizationState state = new DefaultQuantizationState(params); + + byte[] serializedState = state.toByteArray(); + DefaultQuantizationState deserializedState = DefaultQuantizationState.fromByteArray(serializedState); + + assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType()); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java new file mode 100644 index 0000000000..c7f06dd709 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java @@ -0,0 +1,137 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; + +public class MultiBitScalarQuantizerTests extends KNNTestCase { + + public void testTrain_twoBit() { + float[][] vectors = { + {0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}, + {1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f} + }; + MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); + int[] 
sampledIndices = {0, 1, 2}; + SQParams params = new SQParams(SQTypes.TWO_BIT); + TrainingRequest request = new MockTrainingRequest(params, vectors); + request.setSampledIndices(sampledIndices); + QuantizationState state = twoBitQuantizer.train(request); + + assertTrue(state instanceof MultiBitScalarQuantizationState); + MultiBitScalarQuantizationState mbState = (MultiBitScalarQuantizationState) state; + assertNotNull(mbState.getThresholds()); + assertEquals(2, mbState.getThresholds().length); // 2-bit quantization should have 2 thresholds + } + + public void testTrain_fourBit() { + MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4); + float[][] vectors = { + {0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f}, + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}, + {1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f} + }; + int[] sampledIndices = {0, 1, 2}; + SQParams params = new SQParams(SQTypes.FOUR_BIT); + TrainingRequest request = new MockTrainingRequest(params, vectors); + request.setSampledIndices(sampledIndices); + QuantizationState state = fourBitQuantizer.train(request); + + assertTrue(state instanceof MultiBitScalarQuantizationState); + MultiBitScalarQuantizationState mbState = (MultiBitScalarQuantizationState) state; + assertNotNull(mbState.getThresholds()); + assertEquals(4, mbState.getThresholds().length); // 4-bit quantization should have 4 thresholds + } + + public void testQuantize_twoBit() { + MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); + float[] vector = {1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f}; + float[][] thresholds = { + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}, + {1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f} + }; + SQParams params = new SQParams(SQTypes.TWO_BIT); + MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + + QuantizationOutput output = twoBitQuantizer.quantize(vector, state); + assertNotNull(output.getQuantizedVector()); + 
assertEquals(2, output.getQuantizedVector().length); + } + + public void testQuantize_fourBit() { + MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4); + float[] vector = {1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f}; + float[][] thresholds = { + {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}, + {1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f, 7.1f, 8.1f}, + {1.2f, 2.2f, 3.2f, 4.2f, 5.2f, 6.2f, 7.2f, 8.2f}, + {1.3f, 2.3f, 3.3f, 4.3f, 5.3f, 6.3f, 7.3f, 8.3f} + }; + SQParams params = new SQParams(SQTypes.FOUR_BIT); + MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + + QuantizationOutput output = fourBitQuantizer.quantize(vector, state); + assertEquals(4, output.getQuantizedVector().length); + assertNotNull(output.getQuantizedVector()); + } + + public void testQuantize_withNullVector() { + MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); + expectThrows(IllegalArgumentException.class, + () -> twoBitQuantizer.quantize(null, new MultiBitScalarQuantizationState(new SQParams(SQTypes.TWO_BIT), + new float[2][8]))); + } + + public void testQuantize_withInvalidState() { + MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); + float[] vector = {1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f}; + QuantizationState invalidState = new MockInvalidQuantizationState(); + expectThrows(IllegalArgumentException.class, + () -> twoBitQuantizer.quantize(vector, invalidState)); + } + + public void testQuantize_withDefaultQuantizationState() { + MultiBitScalarQuantizer quantizer = new MultiBitScalarQuantizer(2); + float[] vector = {1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f}; + DefaultQuantizationState state = new DefaultQuantizationState(new SQParams(SQTypes.ONE_BIT)); + + expectThrows(UnsupportedOperationException.class, () -> quantizer.quantize(vector, state)); + } + + // Mock classes for testing + private static class MockTrainingRequest extends TrainingRequest { + private 
final float[][] vectors; + + public MockTrainingRequest(SQParams params, float[][] vectors) { + super(params, vectors.length); + this.vectors = vectors; + } + @Override + public float[] getVectorByDocId(int docId) { + return vectors[docId]; + } + } + + private static class MockInvalidQuantizationState implements QuantizationState { + @Override + public SQParams getQuantizationParams() { + return new SQParams(SQTypes.UNSUPPORTED_TYPE); + } + + @Override + public byte[] toByteArray() { + return new byte[0]; + } + } +} \ No newline at end of file diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java new file mode 100644 index 0000000000..9f43f716d1 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java @@ -0,0 +1,137 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.SQTypes; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.SQParams; +import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplingFactory; +import org.opensearch.knn.quantization.util.QuantizerHelper; + +public class OneBitScalarQuantizerTests extends KNNTestCase { + + public void testTrain_withTrainingRequired() { + float[][] vectors = { + {1.0f, 2.0f, 
3.0f}, + {4.0f, 5.0f, 6.0f}, + {7.0f, 8.0f, 9.0f} + }; + + SQParams params = new SQParams(SQTypes.ONE_BIT); + TrainingRequest originalRequest = new TrainingRequest(params, vectors.length) { + @Override + public float[] getVectorByDocId(int docId) { + return vectors[docId]; + } + }; + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + QuantizationState state = quantizer.train(originalRequest); + + assertTrue(state instanceof OneBitScalarQuantizationState); + float[] mean = ((OneBitScalarQuantizationState) state).getMean(); + assertArrayEquals(new float[]{4.0f, 5.0f, 6.0f}, mean, 0.001f); + } + + public void testQuantize_withState() { + float[] vector = {3.0f, 6.0f, 9.0f}; + float[] thresholds = {4.0f, 5.0f, 6.0f}; + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(SQTypes.ONE_BIT), thresholds); + + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + QuantizationOutput output = quantizer.quantize(vector, state); + + assertNotNull(output); + byte[] expectedPackedBits = new byte[]{0b01100000}; // or 96 in decimal + assertArrayEquals(expectedPackedBits, output.getQuantizedVector()); + } + + public void testQuantize_withDefaultQuantizationState() { + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + float[] vector = {3.0f, 6.0f, 9.0f}; + DefaultQuantizationState state = new DefaultQuantizationState(new SQParams(SQTypes.ONE_BIT)); + + expectThrows(UnsupportedOperationException.class, () -> quantizer.quantize(vector, state)); + } + + public void testQuantize_withNullVector() { + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(SQTypes.ONE_BIT), new float[]{0.0f}); + expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(null, state)); + } + + public void testQuantize_withInvalidState() { + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + float[] vector = {1.0f, 2.0f, 
3.0f}; + QuantizationState invalidState = new QuantizationState() { + @Override + public SQParams getQuantizationParams() { + return new SQParams(SQTypes.ONE_BIT); + } + + @Override + public byte[] toByteArray() { + return new byte[0]; + } + }; + expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, invalidState)); + } + + public void testQuantize_withMismatchedDimensions() { + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + float[] vector = {1.0f, 2.0f, 3.0f}; + float[] thresholds = {4.0f, 5.0f}; + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(SQTypes.ONE_BIT), thresholds); + + expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, state)); + } + + public void testCalculateMean() { + float[][] vectors = { + {1.0f, 2.0f, 3.0f}, + {4.0f, 5.0f, 6.0f}, + {7.0f, 8.0f, 9.0f} + }; + + SQParams params = new SQParams(SQTypes.ONE_BIT); + TrainingRequest samplingRequest = new TrainingRequest(params, vectors.length) { + @Override + public float[] getVectorByDocId(int docId) { + return vectors[docId]; + } + }; + + Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); + int[] sampledIndices = sampler.sample(vectors.length, 3); + float[] mean = QuantizerHelper.calculateMean(samplingRequest, sampledIndices); + assertArrayEquals(new float[]{4.0f, 5.0f, 6.0f}, mean, 0.001f); + } + + public void testCalculateMean_withNullVector() { + float[][] vectors = { + {1.0f, 2.0f, 3.0f}, + null, + {7.0f, 8.0f, 9.0f} + }; + + SQParams params = new SQParams(SQTypes.ONE_BIT); + TrainingRequest samplingRequest = new TrainingRequest(params, vectors.length) { + @Override + public float[] getVectorByDocId(int docId) { + return vectors[docId]; + } + }; + + Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); + int[] sampledIndices = sampler.sample(vectors.length, 3); + expectThrows(IllegalArgumentException.class, () -> 
QuantizerHelper.calculateMean(samplingRequest, sampledIndices));
+    }
+}
\ No newline at end of file
diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java
new file mode 100644
index 0000000000..ac73743950
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java
@@ -0,0 +1,103 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+import org.opensearch.knn.KNNTestCase;
+
+import java.util.Arrays;
+import java.util.Random;
+import java.util.stream.IntStream;
+
+/** Unit tests for {@link ReservoirSampler}: boundary sizes, seeding, and the sampling sequence itself. */
+public class ReservoirSamplerTests extends KNNTestCase {
+
+    public void testSampleLessThanSampleSize() { // fewer vectors than requested: every index comes back, in order
+        ReservoirSampler sampler = new ReservoirSampler();
+        int totalNumberOfVectors = 5;
+        int sampleSize = 10;
+        int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+        int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray();
+        assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices);
+    }
+
+    public void testSampleEqualToSampleSize() { // exactly as many vectors as requested: same full, ordered result
+        ReservoirSampler sampler = new ReservoirSampler();
+        int totalNumberOfVectors = 10;
+        int sampleSize = 10;
+        int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+        int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray();
+        assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices);
+    }
+
+    public void testSampleGreaterThanSampleSize() { // more vectors than requested: exactly sampleSize in-range indices
+        ReservoirSampler sampler = new ReservoirSampler(12345); // Fixed seed for reproducibility
+        int totalNumberOfVectors = 100;
+        int sampleSize = 10;
+        int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+        assertEquals(sampleSize, sampledIndices.length);
+        assertTrue(Arrays.stream(sampledIndices).allMatch(i -> i >= 0 && i < totalNumberOfVectors));
+    }
+
+    public void testSampleReproducibility() { // identical seeds must yield identical samples
+        long seed = 12345L;
+        ReservoirSampler sampler1 = new ReservoirSampler(seed);
+        ReservoirSampler sampler2 = new ReservoirSampler(seed);
+        int totalNumberOfVectors = 100;
+        int sampleSize = 10;
+
+        int[] sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize);
+        int[] sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize);
+
+        assertArrayEquals(sampledIndices1, sampledIndices2);
+    }
+
+    public void testSampleRandomness() { // two unseeded samplers are expected to pick different samples
+        ReservoirSampler sampler1 = new ReservoirSampler();
+        ReservoirSampler sampler2 = new ReservoirSampler();
+        int totalNumberOfVectors = 100;
+        int sampleSize = 10;
+
+        int[] sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize);
+        int[] sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize);
+
+        assertNotEquals(Arrays.toString(sampledIndices1), Arrays.toString(sampledIndices2));
+    }
+
+    public void testEdgeCaseZeroVectors() { // nothing to sample from: empty result
+        ReservoirSampler sampler = new ReservoirSampler();
+        int totalNumberOfVectors = 0;
+        int sampleSize = 10;
+        int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+        assertEquals(0, sampledIndices.length);
+    }
+
+    public void testEdgeCaseZeroSampleSize() { // nothing requested: empty result
+        ReservoirSampler sampler = new ReservoirSampler();
+        int totalNumberOfVectors = 10;
+        int sampleSize = 0;
+        int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+        assertEquals(0, sampledIndices.length);
+    }
+
+    public void testReservoirSamplingAlgorithm() { // seeded output must equal a hand-rolled reservoir pass
+        ReservoirSampler sampler = new ReservoirSampler(12345); // Fixed seed for reproducibility
+        int totalNumberOfVectors = 100;
+        int sampleSize = 10;
+        int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+
+        // Replay the reservoir pass with an identically seeded java.util.Random; the sampler is expected to match it
+        int[] reservoir = IntStream.range(0, sampleSize).toArray();
+        Random random = new Random(12345);
+        for (int i = sampleSize; i < totalNumberOfVectors; i++) {
+            int j = random.nextInt(i + 1);
+            if (j < sampleSize) {
+                reservoir[j] = i;
+            }
+        }
+
+        assertArrayEquals(reservoir, sampledIndices);
+    }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java
new file mode 100644
index 0000000000..56d496d2f3
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java
@@ -0,0 +1,19 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+import org.opensearch.knn.KNNTestCase;
+
+public class SamplingFactoryTests extends KNNTestCase {
+    public void testGetSampler_withReservoir() { // RESERVOIR must map to a ReservoirSampler instance
+        Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR);
+        assertTrue(sampler instanceof ReservoirSampler);
+    }
+
+    public void testGetSampler_withUnsupportedType() {
+        expectThrows(NullPointerException.class, () -> SamplingFactory.getSampler(null)); // the "unsupported" input is null
+    }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java b/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java
new file mode 100644
index 0000000000..c0589588c2
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java
@@ -0,0 +1,86 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.util;
+
+import org.opensearch.knn.KNNTestCase;
+
+import java.util.Arrays;
+import java.util.List;
+
+/** Unit tests for {@link BitPackingUtils#packBits}: packed byte layout and argument validation. */
+public class BitPackingUtilsTests extends KNNTestCase {
+
+    public void testPackBits() { // two 8-bit arrays pack into two bytes, first bit in the highest position
+        List<byte[]> bitArrays = Arrays.asList(
+            new byte[]{0, 1, 0, 1, 1, 0, 1, 1},
+            new byte[]{1, 0, 1, 0, 0, 1, 0, 0}
+        );
+
+        byte[] expectedPackedArray = new byte[]{(byte) 0b01011011, (byte) 0b10100100};
+        byte[] packedArray = BitPackingUtils.packBits(bitArrays);
+
+        assertArrayEquals(expectedPackedArray, packedArray);
+    }
+
+    public void testPackBitsEmptyList() { // an empty list is rejected with a specific message
+        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> {
+            BitPackingUtils.packBits(Arrays.asList());
+        });
+        assertEquals("The list of bit arrays cannot be empty.", exception.getMessage());
+    }
+
+    public void testPackBitsNullBitArray() { // a null entry inside the list is rejected
+        List<byte[]> bitArrays = Arrays.asList(
+            new byte[]{0, 1, 0, 1, 1, 0, 1, 1},
+            null
+        );
+
+        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> {
+            BitPackingUtils.packBits(bitArrays);
+        });
+        assertEquals("Bit array cannot be null.", exception.getMessage());
+    }
+
+    public void testPackBitsInconsistentLength() { // arrays of differing lengths are rejected
+        List<byte[]> bitArrays = Arrays.asList(
+            new byte[]{0, 1, 0, 1, 1, 0, 1, 1},
+            new byte[]{1, 0, 1}
+        );
+
+        IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> {
+            BitPackingUtils.packBits(bitArrays);
+        });
+        assertEquals("All bit arrays must have the same length.", exception.getMessage());
+    }
+
+    public void testPackBitsEdgeCaseSingleBitArray() { // one array holding one bit: it lands in the top bit of one byte
+        List<byte[]> bitArrays = Arrays.asList(new byte[]{1});
+
+        byte[] expectedPackedArray = new byte[]{(byte) 0b10000000};
+        byte[] packedArray = BitPackingUtils.packBits(bitArrays);
+
+        assertArrayEquals("Packed array does not match expected output.", expectedPackedArray, packedArray);
+    }
+
+    public void testPackBitsEdgeCaseSingleBit() { // despite the name: two full 8-bit arrays (alternating bits, all ones)
+        List<byte[]> bitArrays = Arrays.asList(
+            new byte[]{1, 0, 1, 0, 1, 0, 1, 0},
+            new byte[]{1, 1, 1, 1, 1, 1, 1, 1}
+        );
+
+        byte[] expectedPackedArray = new byte[]{(byte) 0b10101010, (byte) 0b11111111};
+        byte[] packedArray = BitPackingUtils.packBits(bitArrays);
+
+        assertArrayEquals("Packed array does not match expected output.", expectedPackedArray, packedArray);
+    }
+
+    public void testPackBits_emptyArray() { // same empty-list rejection as testPackBitsEmptyList, exception type only
+        List<byte[]> bitArrays = Arrays.asList();
+        expectThrows(IllegalArgumentException.class, () -> {
+            BitPackingUtils.packBits(bitArrays);
+        });
+    }
+}