diff --git a/CHANGELOG.md b/CHANGELOG.md
index eb8427b1f..dfca7bece 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -35,3 +35,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925)
* Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877)
* Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939)
+* Add Quantization Framework and implement 1-bit and multi-bit quantizers [#1889](https://github.com/opensearch-project/k-NN/issues/1889)
diff --git a/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java
new file mode 100644
index 000000000..40347ad93
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java
@@ -0,0 +1,62 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.enums;
+
+import lombok.Getter;
+
+/**
+ * The ScalarQuantizationType enum defines the various scalar quantization types that can be used
+ * for vector quantization. Each type corresponds to a different bit-width representation of the quantized values.
+ *
+ *
+ * Future Developers: If you change the name of any enum constant, do not change its associated value.
+ * Serialization and deserialization depend on these values to maintain compatibility.
+ *
+ */
+@Getter
+public enum ScalarQuantizationType {
+ /**
+ * ONE_BIT quantization uses a single bit per coordinate.
+ */
+ ONE_BIT(1),
+
+ /**
+ * TWO_BIT quantization uses two bits per coordinate.
+ */
+ TWO_BIT(2),
+
+ /**
+ * FOUR_BIT quantization uses four bits per coordinate.
+ */
+ FOUR_BIT(4);
+
+ private final int id;
+
+ /**
+ * Constructs a ScalarQuantizationType with the specified ID.
+ *
+ * @param id the ID representing the quantization type.
+ */
+ ScalarQuantizationType(int id) {
+ this.id = id;
+ }
+
+ /**
+ * Returns the ScalarQuantizationType associated with the given ID.
+ *
+ * @param id the ID of the quantization type.
+ * @return the corresponding ScalarQuantizationType.
+ * @throws IllegalArgumentException if the ID does not correspond to any ScalarQuantizationType.
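+     *
+     * A minimal illustration:
+     * {@code
+     * ScalarQuantizationType type = ScalarQuantizationType.fromId(2); // returns TWO_BIT
+     * }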
+ */
+ public static ScalarQuantizationType fromId(int id) {
+ for (ScalarQuantizationType type : ScalarQuantizationType.values()) {
+ if (type.getId() == id) {
+ return type;
+ }
+ }
+ throw new IllegalArgumentException("Unknown ScalarQuantizationType ID: " + id);
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java
new file mode 100644
index 000000000..b99f6ebdc
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.factory;
+
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+import org.opensearch.knn.quantization.quantizer.Quantizer;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * The QuantizerFactory class is responsible for creating instances of {@link Quantizer}
+ * based on the provided {@link QuantizationParams}. It uses a registry to look up the
+ * appropriate quantizer implementation for the given quantization parameters.
+ */
+@NoArgsConstructor(access = AccessLevel.PRIVATE)
+public final class QuantizerFactory {
+ private static final AtomicBoolean isRegistered = new AtomicBoolean(false);
+
+ /**
+ * Ensures that default quantizers are registered.
+ */
+ private static void ensureRegistered() {
+ if (!isRegistered.get()) {
+ synchronized (QuantizerFactory.class) {
+ if (!isRegistered.get()) {
+ QuantizerRegistrar.registerDefaultQuantizers();
+ isRegistered.set(true);
+ }
+ }
+ }
+ }
+
+ /**
+ * Retrieves a quantizer instance based on the provided quantization parameters.
+ *
+ * @param params the quantization parameters used to determine the appropriate quantizer
+     * @param <P> the type of quantization parameters, extending {@link QuantizationParams}
+     * @param <T> the type of the input to be quantized
+     * @param <R> the type of the quantized output
+ * @return an instance of {@link Quantizer} corresponding to the provided parameters
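+     *
+     * Example usage (an illustrative sketch; the parameter instance is assumed to be built by the caller):
+     * {@code
+     * ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+     * Quantizer<float[], byte[]> quantizer = QuantizerFactory.getQuantizer(params);
+     * }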
+ */
+    public static <P extends QuantizationParams, T, R> Quantizer<T, R> getQuantizer(final P params) {
+ if (params == null) {
+ throw new IllegalArgumentException("Quantization parameters must not be null.");
+ }
+        // Register defaults lazily here instead of in a static initializer block at class level.
+ ensureRegistered();
+ return QuantizerRegistry.getQuantizer(params);
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java
new file mode 100644
index 000000000..7b542aea0
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.factory;
+
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
+import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer;
+import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer;
+
+/**
+ * The QuantizerRegistrar class is responsible for registering default quantizers.
+ * This class ensures that the registration happens only once in a thread-safe manner.
+ */
+@NoArgsConstructor(access = AccessLevel.PRIVATE)
+final class QuantizerRegistrar {
+
+ /**
+ * Registers default quantizers
+ *
+ * This method is synchronized to ensure that registration occurs only once,
+ * even in a multi-threaded environment.
+ *
+ */
+ static synchronized void registerDefaultQuantizers() {
+        // Register OneBitScalarQuantizer for ScalarQuantizationParams with ScalarQuantizationType.ONE_BIT
+ QuantizerRegistry.register(
+ ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT),
+ new OneBitScalarQuantizer()
+ );
+        // Register MultiBitScalarQuantizer for ScalarQuantizationParams with TWO_BIT (2 bits per coordinate)
+ QuantizerRegistry.register(
+ ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.TWO_BIT),
+ new MultiBitScalarQuantizer(2)
+ );
+        // Register MultiBitScalarQuantizer for ScalarQuantizationParams with FOUR_BIT (4 bits per coordinate)
+ QuantizerRegistry.register(
+ ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.FOUR_BIT),
+ new MultiBitScalarQuantizer(4)
+ );
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java
new file mode 100644
index 000000000..ac266f547
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java
@@ -0,0 +1,59 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.factory;
+
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+import org.opensearch.knn.quantization.quantizer.Quantizer;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * The QuantizerRegistry class is responsible for managing the registration and retrieval
+ * of quantizer instances. Quantizers are registered with specific quantization parameters
+ * and type identifiers, allowing for efficient lookup and instantiation.
+ */
+@NoArgsConstructor(access = AccessLevel.PRIVATE)
+final class QuantizerRegistry {
+ // ConcurrentHashMap for thread-safe access
+    private static final Map<String, Quantizer<?, ?>> registry = new ConcurrentHashMap<>();
+
+ /**
+ * Registers a quantizer with the registry.
+ *
+ * @param paramIdentifier the unique identifier for the quantization parameters
+ * @param quantizer an instance of the quantizer
+ */
+    static void register(final String paramIdentifier, final Quantizer<?, ?> quantizer) {
+ // Check if the quantizer is already registered for the given identifier
+ if (registry.putIfAbsent(paramIdentifier, quantizer) != null) {
+ // Throw an exception if a quantizer is already registered
+ throw new IllegalArgumentException("Quantizer already registered for identifier: " + paramIdentifier);
+ }
+ }
+
+ /**
+ * Retrieves a quantizer instance based on the provided quantization parameters.
+ *
+ * @param params the quantization parameters used to determine the appropriate quantizer
+     * @param <P> the type of quantization parameters
+     * @param <T> the type of the input to be quantized
+     * @param <R> the type of the quantized output
+ * @return an instance of {@link Quantizer} corresponding to the provided parameters
+ * @throws IllegalArgumentException if no quantizer is registered for the given parameters
+ */
+    static <P extends QuantizationParams, T, R> Quantizer<T, R> getQuantizer(final P params) {
+ String identifier = params.getTypeIdentifier();
+        Quantizer<?, ?> quantizer = registry.get(identifier);
+ if (quantizer == null) {
+ throw new IllegalArgumentException("No quantizer registered for type identifier: " + identifier);
+ }
+ @SuppressWarnings("unchecked")
+        Quantizer<T, R> typedQuantizer = (Quantizer<T, R>) quantizer;
+ return typedQuantizer;
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java
new file mode 100644
index 000000000..95592fcb9
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationOutput;
+
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+
+import java.util.Arrays;
+
+/**
+ * The BinaryQuantizationOutput class represents the output of a quantization process in binary format.
+ * It implements the QuantizationOutput interface to handle byte arrays specifically.
+ */
+@NoArgsConstructor
+public class BinaryQuantizationOutput implements QuantizationOutput<byte[]> {
+ @Getter
+ private byte[] quantizedVector;
+
+ /**
+ * Prepares the quantized vector array based on the provided parameters and returns it for direct modification.
+ * This method ensures that the internal byte array is appropriately sized and cleared before being used.
+ * The method accepts two parameters:
+ *
+ * - bitsPerCoordinate: The number of bits used per coordinate. This determines the granularity of the quantization.
+ * - vectorLength: The length of the original vector that needs to be quantized. This helps in calculating the required byte array size.
+ *
+ * If the existing quantized vector is either null or not the same size as the required byte array,
+ * a new byte array is allocated. Otherwise, the existing array is cleared (i.e., all bytes are set to zero).
+ * This method is designed to be used in conjunction with a bit-packing utility that writes quantized values directly
+ * into the returned byte array.
+ * @param params an array of parameters, where the first parameter is the number of bits per coordinate (int),
+ * and the second parameter is the length of the vector (int).
+ * @return the prepared and writable quantized vector as a byte array.
+ * @throws IllegalArgumentException if the parameters are not as expected (e.g., missing or not integers).
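+     *
+     * Example (illustrative; {@code output} is an existing BinaryQuantizationOutput instance): for 2 bits per
+     * coordinate and a vector of length 8, 16 bits are required, so a 2-byte array is prepared:
+     * {@code
+     * byte[] writable = output.prepareAndGetWritableQuantizedVector(2, 8); // writable.length == 2
+     * }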
+ */
+ @Override
+ public byte[] prepareAndGetWritableQuantizedVector(Object... params) {
+ if (params.length != 2 || !(params[0] instanceof Integer) || !(params[1] instanceof Integer)) {
+ throw new IllegalArgumentException("Expected two integer parameters: bitsPerCoordinate and vectorLength");
+ }
+ int bitsPerCoordinate = (int) params[0];
+ int vectorLength = (int) params[1];
+ int totalBits = bitsPerCoordinate * vectorLength;
+ int byteLength = (totalBits + 7) >> 3;
+
+ if (this.quantizedVector == null || this.quantizedVector.length != byteLength) {
+ this.quantizedVector = new byte[byteLength];
+ } else {
+ Arrays.fill(this.quantizedVector, (byte) 0);
+ }
+
+ return this.quantizedVector;
+ }
+
+ /**
+ * Returns the quantized vector.
+ *
+ * @return the quantized vector byte array.
+ */
+ @Override
+ public byte[] getQuantizedVector() {
+ return quantizedVector;
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java
new file mode 100644
index 000000000..aa81a8821
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java
@@ -0,0 +1,28 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationOutput;
+
+/**
+ * The QuantizationOutput interface defines the contract for quantization output data.
+ *
+ * @param <T> The type of the quantized data.
+ */
+public interface QuantizationOutput<T> {
+ /**
+ * Returns the quantized vector.
+ *
+ * @return the quantized data.
+ */
+ T getQuantizedVector();
+
+ /**
+ * Prepares and returns the writable quantized vector for direct modification.
+ *
+ * @param params the parameters needed for preparing the quantized vector.
+ * @return the prepared and writable quantized vector.
+ */
+ T prepareAndGetWritableQuantizedVector(Object... params);
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java
new file mode 100644
index 000000000..4f2ee36c5
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java
@@ -0,0 +1,27 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationParams;
+
+import org.opensearch.core.common.io.stream.Writeable;
+
+/**
+ * Interface for quantization parameters.
+ * This interface defines the basic contract for all quantization parameter types.
+ * It provides methods to retrieve the quantization type and a unique type identifier.
+ * Implementations of this interface are expected to provide specific configurations
+ * for various quantization strategies.
+ */
+public interface QuantizationParams extends Writeable {
+ /**
+ * Provides a unique identifier for the quantization parameters.
+ * This identifier is typically a combination of the quantization type
+ * and additional specifics, and it serves to distinguish between different
+ * configurations or modes of quantization.
+ *
+ * @return a string representing the unique type identifier.
+ */
+ String getTypeIdentifier();
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java
new file mode 100644
index 000000000..4e7a53892
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationParams;
+
+import lombok.AllArgsConstructor;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+
+import java.io.IOException;
+
+/**
+ * The ScalarQuantizationParams class represents the parameters specific to scalar quantization (SQ).
+ * This class implements the QuantizationParams interface and includes the type of scalar quantization.
+ */
+@Getter
+@AllArgsConstructor
+@NoArgsConstructor // No-argument constructor for deserialization
+@EqualsAndHashCode
+public class ScalarQuantizationParams implements QuantizationParams {
+ private ScalarQuantizationType sqType;
+
+ /**
+ * Static method to generate type identifier based on ScalarQuantizationType.
+ *
+ * @param sqType the scalar quantization type.
+ * @return A string representing the unique type identifier.
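+     *
+     * Example (illustrative):
+     * {@code
+     * String id = ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT);
+     * // id is "ScalarQuantizationParams_1"
+     * }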
+ */
+ public static String generateTypeIdentifier(ScalarQuantizationType sqType) {
+ return generateIdentifier(sqType.getId());
+ }
+
+ /**
+ * Provides a unique type identifier for the ScalarQuantizationParams, combining the SQ type.
+ * This identifier is useful for distinguishing between different configurations of scalar quantization parameters.
+ *
+ * @return A string representing the unique type identifier.
+ */
+ @Override
+ public String getTypeIdentifier() {
+ return generateIdentifier(sqType.getId());
+ }
+
+ private static String generateIdentifier(int id) {
+ return "ScalarQuantizationParams_" + id;
+ }
+
+ /**
+ * Writes the object to the output stream.
+ * This method is part of the Writeable interface and is used to serialize the object.
+ *
+ * @param out the output stream to write the object to.
+ * @throws IOException if an I/O error occurs.
+ */
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeVInt(sqType.getId());
+ }
+
+ /**
+ * Reads the object from the input stream.
+ * This method is part of the Writeable interface and is used to deserialize the object.
+ *
+     * @param in the input stream to read the object from.
+     * @param version the serialization version associated with the stream.
+ * @throws IOException if an I/O error occurs.
+ */
+ public ScalarQuantizationParams(StreamInput in, int version) throws IOException {
+ int typeId = in.readVInt();
+ this.sqType = ScalarQuantizationType.fromId(typeId);
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java
new file mode 100644
index 000000000..33e775cad
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationState;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import org.opensearch.Version;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
+
+import java.io.IOException;
+
+/**
+ * DefaultQuantizationState is used as a fallback state when no training is required or if training fails.
+ * It can be utilized by any quantizer to represent a default state.
+ */
+@Getter
+@NoArgsConstructor // No-argument constructor for deserialization
+@AllArgsConstructor
+public class DefaultQuantizationState implements QuantizationState {
+ private QuantizationParams params;
+
+ @Override
+ public QuantizationParams getQuantizationParams() {
+ return params;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeInt(Version.CURRENT.id); // Write the version
+ params.writeTo(out);
+ }
+
+ public DefaultQuantizationState(StreamInput in) throws IOException {
+ int version = in.readInt(); // Read the version
+ this.params = new ScalarQuantizationParams(in, version);
+ }
+
+ /**
+ * Serializes the quantization state to a byte array.
+ *
+ * @return a byte array representing the serialized state.
+ * @throws IOException if an I/O error occurs during serialization.
+ */
+ @Override
+ public byte[] toByteArray() throws IOException {
+ return QuantizationStateSerializer.serialize(this);
+ }
+
+ /**
+ * Deserializes a DefaultQuantizationState from a byte array.
+ *
+ * @param bytes the byte array containing the serialized state.
+ * @return the deserialized DefaultQuantizationState.
+ * @throws IOException if an I/O error occurs during deserialization.
+ * @throws ClassNotFoundException if the class of the serialized object cannot be found.
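+     *
+     * Example (illustrative; {@code state} is an existing DefaultQuantizationState):
+     * {@code
+     * byte[] serialized = state.toByteArray();
+     * DefaultQuantizationState restored = DefaultQuantizationState.fromByteArray(serialized);
+     * }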
+ */
+ public static DefaultQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException {
+ return (DefaultQuantizationState) QuantizationStateSerializer.deserialize(bytes, DefaultQuantizationState::new);
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java
new file mode 100644
index 000000000..09092fde8
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java
@@ -0,0 +1,127 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationState;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import org.opensearch.Version;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
+
+import java.io.IOException;
+
+/**
+ * MultiBitScalarQuantizationState represents the state of multi-bit scalar quantization,
+ * including the thresholds used for quantization.
+ */
+@Getter
+@NoArgsConstructor // No-argument constructor for deserialization
+@AllArgsConstructor
+public final class MultiBitScalarQuantizationState implements QuantizationState {
+ private ScalarQuantizationParams quantizationParams;
+ /**
+ * The threshold values for multi-bit quantization, organized as a 2D array
+ * where each row corresponds to a different bit level.
+ *
+ * For example:
+ * - For 2-bit quantization:
+ * thresholds[0] -> {0.5f, 1.5f, 2.5f} // Thresholds for the first bit level
+ * thresholds[1] -> {1.0f, 2.0f, 3.0f} // Thresholds for the second bit level
+ * - For 4-bit quantization:
+ * thresholds[0] -> {0.1f, 0.2f, 0.3f}
+ * thresholds[1] -> {0.4f, 0.5f, 0.6f}
+ * thresholds[2] -> {0.7f, 0.8f, 0.9f}
+ * thresholds[3] -> {1.0f, 1.1f, 1.2f}
+ *
+ * Each column represents the threshold for a specific dimension in the vector space.
+ */
+ private float[][] thresholds;
+
+ @Override
+ public ScalarQuantizationParams getQuantizationParams() {
+ return quantizationParams;
+ }
+
+ /**
+ * This method is responsible for writing the state of the MultiBitScalarQuantizationState object to an external output.
+ * It includes versioning information to ensure compatibility between different versions of the serialized object.
+ *
+ * @param out the StreamOutput to write the object to.
+ * @throws IOException if an I/O error occurs during serialization.
+ */
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeVInt(Version.CURRENT.id); // Write the version
+ quantizationParams.writeTo(out);
+ out.writeVInt(thresholds.length); // Number of rows
+ for (float[] row : thresholds) {
+ out.writeFloatArray(row); // Write each row as a float array
+ }
+ }
+
+ /**
+ * This method is responsible for reading the state of the MultiBitScalarQuantizationState object from an external input.
+ * It includes versioning information to ensure compatibility between different versions of the serialized object.
+ *
+ * @param in the StreamInput to read the object from.
+ * @throws IOException if an I/O error occurs during deserialization.
+ */
+ public MultiBitScalarQuantizationState(StreamInput in) throws IOException {
+ int version = in.readVInt(); // Read the version
+ this.quantizationParams = new ScalarQuantizationParams(in, version);
+ int rows = in.readVInt(); // Read the number of rows
+ this.thresholds = new float[rows][];
+ for (int i = 0; i < rows; i++) {
+ this.thresholds[i] = in.readFloatArray(); // Read each row as a float array
+ }
+ }
+
+ /**
+ * Serializes the current state of this MultiBitScalarQuantizationState object into a byte array.
+ * This method uses the QuantizationStateSerializer to handle the serialization process.
+ *
+ * The serialized byte array includes all necessary state information, such as the thresholds
+ * and quantization parameters, ensuring that the object can be fully reconstructed from the byte array.
+ *
+ *
+ * {@code
+ * MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+ * byte[] serializedState = state.toByteArray();
+ * }
+ *
+ *
+ * @return a byte array representing the serialized state of this object.
+ * @throws IOException if an I/O error occurs during serialization.
+ */
+ @Override
+ public byte[] toByteArray() throws IOException {
+ return QuantizationStateSerializer.serialize(this);
+ }
+
+ /**
+ * Deserializes a MultiBitScalarQuantizationState object from a byte array.
+ * This method uses the QuantizationStateSerializer to handle the deserialization process.
+ *
+ * The byte array should contain serialized state information, including the thresholds
+ * and quantization parameters, which are necessary to reconstruct the MultiBitScalarQuantizationState object.
+ *
+ *
+ * {@code
+ * byte[] serializedState = ...; // obtain the byte array from some source
+ * MultiBitScalarQuantizationState state = MultiBitScalarQuantizationState.fromByteArray(serializedState);
+ * }
+ *
+ *
+ * @param bytes the byte array containing the serialized state.
+ * @return the deserialized MultiBitScalarQuantizationState object.
+ * @throws IOException if an I/O error occurs during deserialization.
+ */
+ public static MultiBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException {
+ return (MultiBitScalarQuantizationState) QuantizationStateSerializer.deserialize(bytes, MultiBitScalarQuantizationState::new);
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java
new file mode 100644
index 000000000..9998b87e8
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java
@@ -0,0 +1,110 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationState;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import org.opensearch.Version;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
+
+import java.io.IOException;
+
+/**
+ * OneBitScalarQuantizationState represents the state of one-bit scalar quantization,
+ * including the mean values used for quantization.
+ */
+@Getter
+@NoArgsConstructor // No-argument constructor for deserialization
+@AllArgsConstructor
+public final class OneBitScalarQuantizationState implements QuantizationState {
+ private ScalarQuantizationParams quantizationParams;
+ /**
+ * Mean thresholds used in the quantization process.
+ * Each threshold value corresponds to a dimension of the vector being quantized.
+ *
+ * Example:
+ * If we have a vector [1.2, 3.4, 5.6] and mean thresholds [2.0, 3.0, 4.0],
+ * The quantized vector will be [0, 1, 1].
+ */
+ private float[] meanThresholds;
+
+ @Override
+ public ScalarQuantizationParams getQuantizationParams() {
+ return quantizationParams;
+ }
+
+ /**
+ * This method is responsible for writing the state of the OneBitScalarQuantizationState object to an external output.
+ * It includes versioning information to ensure compatibility between different versions of the serialized object.
+ * @param out the StreamOutput to write the object to.
+ * @throws IOException if an I/O error occurs during serialization.
+ */
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ out.writeVInt(Version.CURRENT.id); // Write the version
+ quantizationParams.writeTo(out);
+ out.writeFloatArray(meanThresholds);
+ }
+
+ /**
+ * This method is responsible for reading the state of the OneBitScalarQuantizationState object from an external input.
+ * It includes versioning information to ensure compatibility between different versions of the serialized object.
+ * @param in the StreamInput to read the object from.
+ * @throws IOException if an I/O error occurs during deserialization.
+ */
+ public OneBitScalarQuantizationState(StreamInput in) throws IOException {
+ int version = in.readVInt(); // Read the version
+ this.quantizationParams = new ScalarQuantizationParams(in, version);
+ this.meanThresholds = in.readFloatArray();
+ }
+
+ /**
+ * Serializes the current state of this OneBitScalarQuantizationState object into a byte array.
+ * This method uses the QuantizationStateSerializer to handle the serialization process.
+ *
+ * The serialized byte array includes all necessary state information, such as the mean thresholds
+ * and quantization parameters, ensuring that the object can be fully reconstructed from the byte array.
+ *
+ *
+ * {@code
+ * OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, meanThresholds);
+ * byte[] serializedState = state.toByteArray();
+ * }
+ *
+ *
+ * @return a byte array representing the serialized state of this object.
+ * @throws IOException if an I/O error occurs during serialization.
+ */
+ @Override
+ public byte[] toByteArray() throws IOException {
+ return QuantizationStateSerializer.serialize(this);
+ }
+
+ /**
+ * Deserializes a OneBitScalarQuantizationState object from a byte array.
+ * This method uses the QuantizationStateSerializer to handle the deserialization process.
+ *
+ * The byte array should contain serialized state information, including the mean thresholds
+ * and quantization parameters, which are necessary to reconstruct the OneBitScalarQuantizationState object.
+ *
+ *
+ * {@code
+ * byte[] serializedState = ...; // obtain the byte array from some source
+ * OneBitScalarQuantizationState state = OneBitScalarQuantizationState.fromByteArray(serializedState);
+ * }
+ *
+ *
+ * @param bytes the byte array containing the serialized state.
+ * @return the deserialized OneBitScalarQuantizationState object.
+ * @throws IOException if an I/O error occurs during deserialization.
+ */
+ public static OneBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException {
+ return (OneBitScalarQuantizationState) QuantizationStateSerializer.deserialize(bytes, OneBitScalarQuantizationState::new);
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java
new file mode 100644
index 000000000..e32df8bc3
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java
@@ -0,0 +1,32 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationState;
+
+import org.opensearch.core.common.io.stream.Writeable;
+import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
+
+import java.io.IOException;
+
+/**
+ * QuantizationState interface represents the state of a quantization process, including the parameters used.
+ * This interface provides methods for serializing and deserializing the state.
+ */
+public interface QuantizationState extends Writeable {
+ /**
+ * Returns the quantization parameters associated with this state.
+ *
+ * @return the quantization parameters.
+ */
+ QuantizationParams getQuantizationParams();
+
+ /**
+ * Serializes the quantization state to a byte array.
+ *
+ * @return a byte array representing the serialized state.
+ * @throws IOException if an I/O error occurs during serialization.
+ */
+ byte[] toByteArray() throws IOException;
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java
new file mode 100644
index 000000000..1f378e0dc
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java
@@ -0,0 +1,56 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.quantizationState;
+
+import lombok.experimental.UtilityClass;
+import org.opensearch.common.io.stream.BytesStreamOutput;
+import org.opensearch.core.common.io.stream.StreamInput;
+
+import java.io.IOException;
+
+/**
+ * QuantizationStateSerializer is a utility class that provides methods for serializing and deserializing
+ * QuantizationState objects along with their specific data.
+ */
+@UtilityClass
+class QuantizationStateSerializer {
+
+ /**
+ * A functional interface for deserializing specific data associated with a QuantizationState.
+ */
+ @FunctionalInterface
+ interface SerializableDeserializer {
+ QuantizationState deserialize(StreamInput in) throws IOException;
+ }
+
+ /**
+ * Serializes the QuantizationState and specific data into a byte array.
+ *
+ * @param state The QuantizationState to serialize.
+ * @return A byte array representing the serialized state and specific data.
+ * @throws IOException If an I/O error occurs during serialization.
+ */
+ static byte[] serialize(QuantizationState state) throws IOException {
+ try (BytesStreamOutput out = new BytesStreamOutput()) {
+ state.writeTo(out);
+ return out.bytes().toBytesRef().bytes;
+ }
+ }
+
+ /**
+ * Deserializes a QuantizationState and its specific data from a byte array.
+ *
+ * @param bytes The byte array containing the serialized data.
+ * @param deserializer The deserializer for the specific data associated with the state.
+ * @return The deserialized QuantizationState including its specific data.
+ * @throws IOException If an I/O error occurs during deserialization.
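+     *
+     * Example (illustrative):
+     * {@code
+     * QuantizationState state = QuantizationStateSerializer.deserialize(bytes, OneBitScalarQuantizationState::new);
+     * }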
+ */
+ static QuantizationState deserialize(byte[] bytes, SerializableDeserializer deserializer) throws IOException {
+ try (StreamInput in = StreamInput.wrap(bytes)) {
+ return deserializer.deserialize(in);
+ }
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java
new file mode 100644
index 000000000..54ebe311c
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.models.requests;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+/**
+ * TrainingRequest represents a request for training a quantizer.
+ *
+ * @param <T> the type of vectors to be trained.
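+ *
+ * Example of supplying training vectors (an illustrative sketch; {@code vectors} is a hypothetical in-memory source):
+ * {@code
+ * TrainingRequest<float[]> request = new TrainingRequest<float[]>(vectors.length) {
+ *     public float[] getVectorByDocId(int docId) {
+ *         return vectors[docId];
+ *     }
+ * };
+ * }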
+ */
+@Getter
+@AllArgsConstructor
+public abstract class TrainingRequest<T> {
+ /**
+ * The total number of vectors in one segment.
+ */
+ private final int totalNumberOfVectors;
+
+ /**
+ * Returns the vector corresponding to the specified document ID.
+ *
+ * @param docId the document ID.
+ * @return the vector corresponding to the specified document ID.
+ */
+ public abstract T getVectorByDocId(int docId);
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java b/src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java
new file mode 100644
index 000000000..fe470ed74
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java
@@ -0,0 +1,143 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import lombok.experimental.UtilityClass;
+
+/**
+ * The BitPacker class provides utility methods for quantizing floating-point vectors and packing the resulting bits
+ * into a pre-allocated byte array. This class supports both single-bit and multi-bit quantization scenarios,
+ * enabling efficient storage and transmission of quantized vectors.
+ *
+ *
+ * The methods in this class are designed to be used by quantizers that need to convert floating-point vectors
+ * into compact binary representations by comparing them against quantization thresholds.
+ *
+ *
+ *
+ * This class is marked as a utility class using Lombok's {@link lombok.experimental.UtilityClass} annotation,
+ * making it a singleton and preventing instantiation.
+ *
+ */
+@UtilityClass
+class BitPacker {
+
+ /**
+ * Quantizes a given floating-point vector and packs the resulting quantized bits into a provided byte array.
+ * This method operates by comparing each element of the input vector against corresponding thresholds
+ * and encoding the results into a compact binary format using the specified number of bits per coordinate.
+ *
+ *
+ * The method supports multi-bit quantization where each coordinate of the input vector can be represented
+ * by multiple bits. For example, with 2-bit quantization, each coordinate is encoded into 2 bits, allowing
+ * for four distinct levels of quantization per coordinate.
+ *
+ *
+ *
+ * Example:
+ *
+ *
+     * Consider a vector with 3 coordinates: [1.2, 3.4, 5.6] and thresholds:
+ *
+ *
+ * thresholds = {
+ * {1.0, 3.0, 5.0}, // First bit thresholds
+ * {1.5, 3.5, 5.5} // Second bit thresholds
+ * };
+ *
+ *
+ * If the number of bits per coordinate is 2, the quantization process will proceed as follows:
+ *
+ *
+ * - First bit comparison:
+ *
+ * - 1.2 > 1.0 -> 1
+ * - 3.4 > 3.0 -> 1
+ * - 5.6 > 5.0 -> 1
+ *
+ *
+ * - Second bit comparison:
+ *
+ * - 1.2 <= 1.5 -> 0
+ * - 3.4 <= 3.5 -> 0
+ * - 5.6 > 5.5 -> 1
+ *
+ *
+ *
+ *
+     * The resulting quantized bits, in packing order, are 111 001, which are packed into the provided byte array.
+     * If there are fewer than 8 bits, the remaining bits in the byte are set to 0.
+     *
+     * Packing Process:
+     * The quantized bits are packed into the byte array in bit-level order: the first-bit results for all coordinates
+     * are stored in the most significant positions of the first byte, followed by the second-bit results, and so on.
+     * In the example above, the resulting byte array will have the following binary representation:
+     *
+     * packedBits = [11100100] // Only the first 6 bits are used, and the last two are set to 0.
+ *
+ *
+ * Bitwise Operations Explanation:
+ *
+     * - byteIndex: This is calculated using {@code byteIndex = bitPosition >> 3}, which is equivalent to
+     *   {@code bitPosition / 8}. It determines which byte in the byte array the current bit should be placed in.
+     * - bitIndex: This is calculated using {@code bitIndex = 7 - (bitPosition & 7)}, which is equivalent to
+     *   {@code 7 - (bitPosition % 8)}. It determines the exact bit position within the byte.
+     * - Setting the bit: The bit is set using {@code packedBits[byteIndex] |= (1 << bitIndex)}. This shifts a 1 into
+     *   the correct bit position and ORs it with the existing byte value, as shown in the walk-through below.
+ *
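+     * For instance (an illustrative walk-through of the arithmetic above), bit position 5 lands in the first byte:
+     * {@code
+     * int bitPosition = 5;
+     * int byteIndex = bitPosition >> 3;          // 0 -> first byte
+     * int bitIndex = 7 - (bitPosition & 7);      // 2 -> third bit from the least significant end
+     * packedBits[byteIndex] |= (1 << bitIndex);  // ORs in 0b00000100
+     * }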
+ *
+ * @param vector the floating-point vector to be quantized.
+ * @param thresholds a 2D array representing the quantization thresholds. The first dimension corresponds to the number of bits per coordinate, and the second dimension corresponds to the vector's length.
+ * @param bitsPerCoordinate the number of bits used per coordinate, determining the granularity of the quantization.
+ * @param packedBits the byte array where the quantized bits will be packed.
+ */
+ void quantizeAndPackBits(final float[] vector, final float[][] thresholds, final int bitsPerCoordinate, byte[] packedBits) {
+ int vectorLength = vector.length;
+
+ for (int i = 0; i < bitsPerCoordinate; i++) {
+ for (int j = 0; j < vectorLength; j++) {
+ if (vector[j] > thresholds[i][j]) {
+ int bitPosition = i * vectorLength + j;
+ // Calculate the index of the byte in the packedBits array.
+ int byteIndex = bitPosition >> 3; // Equivalent to bitPosition / 8
+ // Calculate the bit index within the byte.
+ int bitIndex = 7 - (bitPosition & 7); // Equivalent to 7 - (bitPosition % 8)
+ // Set the bit at the calculated position.
+ packedBits[byteIndex] |= (1 << bitIndex); // Set the bit at bitIndex
+ }
+ }
+ }
+ }
+
+ /**
+ * Overloaded method to quantize a vector using single-bit quantization and pack the results into a provided byte array.
+ *
+ *
+ * This method is specifically designed for one-bit quantization scenarios, where each coordinate of the
+ * vector is represented by a single bit indicating whether the value is above or below the threshold.
+ *
+ *
+ * Example:
+ *
+     * If we have a vector [1.2, 3.4, 5.6] and thresholds [2.0, 3.0, 4.0], the quantization process will be:
+     *
+     * - 1.2 < 2.0 -> 0
+     * - 3.4 > 3.0 -> 1
+     * - 5.6 > 4.0 -> 1
+     *
+     * The quantized vector will be [0, 1, 1].
+ *
+ *
+ * @param vector the vector to quantize.
+ * @param thresholds the thresholds for quantization, where each element represents the threshold for a corresponding coordinate.
+ * @param packedBits the byte array where the quantized bits will be packed.
+ */
+ void quantizeAndPackBits(final float[] vector, final float[] thresholds, byte[] packedBits) {
+ quantizeAndPackBits(vector, new float[][] { thresholds }, 1, packedBits);
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java
new file mode 100644
index 000000000..dcf825a6a
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java
@@ -0,0 +1,186 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
+import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+import org.opensearch.knn.quantization.sampler.Sampler;
+import org.opensearch.knn.quantization.sampler.SamplerType;
+import org.opensearch.knn.quantization.sampler.SamplingFactory;
+
+/**
+ * MultiBitScalarQuantizer is responsible for quantizing vectors into multi-bit representations per dimension.
+ * Unlike the OneBitScalarQuantizer, which uses a single bit per dimension to represent whether a value is above
+ * or below a mean threshold, the MultiBitScalarQuantizer allows for multiple bits per dimension, enabling more
+ * granular and precise quantization.
+ *
+ *
+ * In a OneBitScalarQuantizer, each dimension of a vector is compared to a single threshold (the mean), and a single
+ * bit is used to indicate whether the value is above or below that threshold. This results in a very coarse
+ * representation where each dimension is either "on" or "off."
+ *
+ *
+ *
+ * The MultiBitScalarQuantizer, on the other hand, uses multiple thresholds per dimension, one per bit. For example,
+ * in a 2-bit quantization scheme, two thresholds are used per dimension, and each bit records whether the value lies
+ * above the corresponding threshold. This allows for a much finer representation of the data, capturing more
+ * nuances in the variation of each dimension.
+ *
+ *
+ *
+ * The thresholds in MultiBitScalarQuantizer are calculated based on the mean and standard deviation of the sampled
+ * vectors for each dimension. Here's how it works:
+ *
+ *
+ *
+ * - First, the mean and standard deviation are computed for each dimension across the sampled vectors.
+ * - For each bit used in the quantization (e.g., 2 bits per coordinate), the thresholds are calculated
+ * using a linear combination of the mean and the standard deviation. The combination coefficients are
+ * determined by the number of bits, allowing the thresholds to split the data into equal probability regions.
+ *
+ * - For example, in 2-bit quantization the two thresholds for a dimension are placed at mean - (1/3) * stdDev and
+ *   mean + (1/3) * stdDev, so the typical range of values in that dimension is split into regions of roughly equal
+ *   probability, as illustrated below.
+ *
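+ * A small worked illustration of the coefficient formula used in this class, iCoef = -1 + 2 * (i + 1) / (bits + 1):
+ * {@code
+ * // 2-bit quantization, one dimension with mean = 2.0 and stdDev = 0.9:
+ * // i = 0: iCoef = -1 + 2 / 3 = -1/3  -> threshold = 2.0 - 0.3 = 1.7
+ * // i = 1: iCoef = -1 + 4 / 3 = +1/3  -> threshold = 2.0 + 0.3 = 2.3
+ * }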
+ *
+ *
+ *
+ * The number of bits per coordinate is determined by the type of scalar quantization being applied, such as 2-bit
+ * or 4-bit quantization. The increased number of bits per coordinate in MultiBitScalarQuantizer allows for better
+ * preservation of information during the quantization process, making it more suitable for tasks where precision
+ * is crucial. However, this comes at the cost of increased storage and computational complexity compared to the
+ * simpler OneBitScalarQuantizer.
+ *
+ */
+public class MultiBitScalarQuantizer implements Quantizer<float[], byte[]> {
+ private final int bitsPerCoordinate; // Number of bits used to quantize each dimension
+ private final int samplingSize; // Sampling size for training
+ private final Sampler sampler; // Sampler for training
+ private static final boolean IS_TRAINING_REQUIRED = true;
+    // Lucene currently uses a sampling size of 25000 for segment-level training.
+    // We keep the same value for consistency and will revisit it if it needs to change.
+ private static final int DEFAULT_SAMPLE_SIZE = 25000;
+
+ /**
+ * Constructs a MultiBitScalarQuantizer with a specified number of bits per coordinate.
+ *
+ * @param bitsPerCoordinate the number of bits used per coordinate for quantization.
+ */
+ public MultiBitScalarQuantizer(final int bitsPerCoordinate) {
+ this(bitsPerCoordinate, DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplerType.RESERVOIR));
+ }
+
+ /**
+ * Constructs a MultiBitScalarQuantizer with a specified number of bits per coordinate and sampling size.
+ *
+ * @param bitsPerCoordinate the number of bits used per coordinate for quantization.
+ * @param samplingSize the number of samples to use for training.
+ * @param sampler the sampler to use for training.
+ */
+ public MultiBitScalarQuantizer(final int bitsPerCoordinate, final int samplingSize, final Sampler sampler) {
+ if (bitsPerCoordinate < 2) {
+ throw new IllegalArgumentException("bitsPerCoordinate must be greater than or equal to 2 for multibit quantizer.");
+ }
+ this.bitsPerCoordinate = bitsPerCoordinate;
+ this.samplingSize = samplingSize;
+ this.sampler = sampler;
+ }
+
+ /**
+     * Trains the quantizer based on the provided training request.
+ * The training process calculates the mean and standard deviation for each dimension and then determines
+ * threshold values for quantization based on these statistics.
+ *
+ * @param trainingRequest the request containing the data and parameters for training.
+ * @return a MultiBitScalarQuantizationState containing the computed thresholds.
+ */
+ @Override
+    public QuantizationState train(final TrainingRequest<float[]> trainingRequest) {
+ int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize);
+ int dimension = trainingRequest.getVectorByDocId(sampledIndices[0]).length;
+ float[] meanArray = new float[dimension];
+ float[] stdDevArray = new float[dimension];
+ // Calculate sum, mean, and standard deviation in one pass
+ QuantizerHelper.calculateMeanAndStdDev(trainingRequest, sampledIndices, meanArray, stdDevArray);
+ float[][] thresholds = calculateThresholds(meanArray, stdDevArray, dimension);
+ ScalarQuantizationParams params = (bitsPerCoordinate == 2)
+ ? new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT)
+ : new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
+ return new MultiBitScalarQuantizationState(params, thresholds);
+ }
+
+ /**
+ * Quantizes the provided vector using the provided quantization state, producing a quantized output.
+ * The vector is quantized based on the thresholds in the quantization state.
+ *
+ * @param vector the vector to quantize.
+ * @param state the quantization state containing threshold information.
+ * @param output the QuantizationOutput object to store the quantized representation of the vector.
+ */
+ @Override
+    public void quantize(final float[] vector, final QuantizationState state, final QuantizationOutput<byte[]> output) {
+ if (vector == null) {
+ throw new IllegalArgumentException("Vector to quantize must not be null.");
+ }
+ validateState(state);
+ MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) state;
+ float[][] thresholds = multiBitState.getThresholds();
+ if (thresholds == null || thresholds[0].length != vector.length) {
+ throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
+ }
+ // Prepare and get the writable array
+ byte[] writableArray = output.prepareAndGetWritableQuantizedVector(bitsPerCoordinate, vector.length);
+ BitPacker.quantizeAndPackBits(vector, thresholds, bitsPerCoordinate, writableArray);
+ }
+
+ /**
+ * Calculates the thresholds for quantization based on mean and standard deviation.
+ *
+ * @param meanArray the mean for each dimension.
+ * @param stdDevArray the standard deviation for each dimension.
+ * @param dimension the number of dimensions in the vectors.
+ * @return the thresholds for quantization.
+ */
+ private float[][] calculateThresholds(final float[] meanArray, final float[] stdDevArray, final int dimension) {
+ float[][] thresholds = new float[bitsPerCoordinate][dimension];
+ float coef = bitsPerCoordinate + 1;
+ for (int i = 0; i < bitsPerCoordinate; i++) {
+ float iCoef = -1 + 2 * (i + 1) / coef;
+ for (int j = 0; j < dimension; j++) {
+ thresholds[i][j] = meanArray[j] + iCoef * stdDevArray[j];
+ }
+ }
+ return thresholds;
+ }
+
+ /**
+ * Validates the quantization state to ensure it is of the expected type.
+ *
+ * @param state the quantization state to validate.
+ * @throws IllegalArgumentException if the state is not an instance of MultiBitScalarQuantizationState.
+ */
+ private void validateState(final QuantizationState state) {
+ if (!(state instanceof MultiBitScalarQuantizationState)) {
+ throw new IllegalArgumentException("Quantization state must be of type MultiBitScalarQuantizationState.");
+ }
+ }
+
+ /**
+ * Returns the number of bits per coordinate used by this quantizer.
+ *
+ * @return the number of bits per coordinate.
+ */
+ public int getBitsPerCoordinate() {
+ return bitsPerCoordinate;
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java
new file mode 100644
index 000000000..41602dfd2
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java
@@ -0,0 +1,100 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
+import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+import org.opensearch.knn.quantization.sampler.Sampler;
+import org.opensearch.knn.quantization.sampler.SamplerType;
+import org.opensearch.knn.quantization.sampler.SamplingFactory;
+
+/**
+ * OneBitScalarQuantizer is responsible for quantizing vectors using a single bit per dimension.
+ * It computes the mean of each dimension during training and then uses these means as thresholds
+ * for quantizing the vectors.
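+ *
+ * Typical usage (an illustrative sketch; {@code trainingRequest} and {@code vector} are assumed to be provided by the caller):
+ * {@code
+ * OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ * QuantizationState state = quantizer.train(trainingRequest);          // learns per-dimension mean thresholds
+ * QuantizationOutput<byte[]> output = new BinaryQuantizationOutput();
+ * quantizer.quantize(vector, state, output);                           // packs one bit per dimension
+ * byte[] encoded = output.getQuantizedVector();
+ * }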
+ */
+public class OneBitScalarQuantizer implements Quantizer<float[], byte[]> {
+ private final int samplingSize; // Sampling size for training
+ private static final boolean IS_TRAINING_REQUIRED = true;
+ private final Sampler sampler; // Sampler for training
+    // Lucene currently uses a sampling size of 25000 for segment-level training.
+    // We keep the same value for consistency and will revisit it if it needs to change.
+ private static final int DEFAULT_SAMPLE_SIZE = 25000;
+
+ /**
+ * Constructs a OneBitScalarQuantizer with a default sampling size of 25000.
+ */
+ public OneBitScalarQuantizer() {
+ this(DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplerType.RESERVOIR));
+ }
+
+ /**
+     * Constructs a OneBitScalarQuantizer with a specified sampling size and sampler.
+     *
+     * @param samplingSize the number of samples to use for training.
+     * @param sampler the sampler to use for selecting training vectors.
+     */
+    public OneBitScalarQuantizer(final int samplingSize, final Sampler sampler) {
+ this.samplingSize = samplingSize;
+ this.sampler = sampler;
+ }
+
+ /**
+ * Trains the quantizer by calculating the mean of each dimension from the sampled vectors.
+ * These means are used as thresholds in the quantization process.
+ *
+ * @param trainingRequest the request containing the data and parameters for training.
+ * @return a OneBitScalarQuantizationState containing the calculated means.
+ */
+ @Override
+    public QuantizationState train(final TrainingRequest<float[]> trainingRequest) {
+ int[] sampledDocIds = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize);
+ float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(trainingRequest, sampledDocIds);
+ return new OneBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), meanThresholds);
+ }
+
+ /**
+ * Quantizes the provided vector using the given quantization state.
+ * It compares each dimension of the vector against the corresponding mean (threshold) to determine the quantized value.
+ *
+ * @param vector the vector to quantize.
+ * @param state the quantization state containing the means for each dimension.
+ * @param output the QuantizationOutput object to store the quantized representation of the vector.
+ */
+ @Override
+    public void quantize(final float[] vector, final QuantizationState state, final QuantizationOutput<byte[]> output) {
+ if (vector == null) {
+ throw new IllegalArgumentException("Vector to quantize must not be null.");
+ }
+ validateState(state);
+ OneBitScalarQuantizationState binaryState = (OneBitScalarQuantizationState) state;
+ float[] thresholds = binaryState.getMeanThresholds();
+ if (thresholds == null || thresholds.length != vector.length) {
+ throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
+ }
+ // Prepare and get the writable array
+ byte[] writableArray = output.prepareAndGetWritableQuantizedVector(1, vector.length);
+ BitPacker.quantizeAndPackBits(vector, thresholds, writableArray);
+ }
+
+ /**
+ * Validates the quantization state to ensure it is of the expected type.
+ *
+ * @param state the quantization state to validate.
+ * @throws IllegalArgumentException if the state is not an instance of OneBitScalarQuantizationState.
+ */
+ private void validateState(final QuantizationState state) {
+ if (!(state instanceof OneBitScalarQuantizationState)) {
+ throw new IllegalArgumentException("Quantization state must be of type OneBitScalarQuantizationState.");
+ }
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java
new file mode 100644
index 000000000..c0b297f5d
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+
+/**
+ * The Quantizer interface defines the methods required for training and quantizing vectors
+ * in the context of K-Nearest Neighbors (KNN) and similar machine learning tasks.
+ * It supports training to determine quantization parameters and quantizing data vectors
+ * based on these parameters.
+ *
+ * @param <T> The type of the vector or data to be quantized.
+ * @param <R> The type of the quantized output, typically a compressed or encoded representation.
+ */
+public interface Quantizer<T, R> {
+
+ /**
+ * Trains the quantizer based on the provided training request. The training process typically
+ * involves learning parameters that can be used to quantize vectors.
+ *
+ * @param trainingRequest the request containing data and parameters for training.
+ * @return a QuantizationState containing the learned parameters.
+ */
+ QuantizationState train(TrainingRequest<T> trainingRequest);
+
+ /**
+ * Quantizes the provided vector using the specified quantization state.
+ *
+ * @param vector the vector to quantize.
+ * @param state the quantization state containing parameters for quantization.
+ * @param output the QuantizationOutput object to store the quantized representation of the vector.
+ */
+ void quantize(T vector, QuantizationState state, QuantizationOutput<R> output);
+}
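
To make the generic contract above concrete, here is a small helper that programs purely against the interface. The type parameters follow the reconstruction above (T = input vector type, R = packed output type); the helper class itself is a hypothetical illustration, not part of this change:

import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
import org.opensearch.knn.quantization.quantizer.Quantizer;

public final class QuantizerUsage {
    private QuantizerUsage() {}

    /**
     * Trains the given quantizer, then quantizes one vector with the learned state.
     * Works for any (T, R) pairing, e.g. a Quantizer<float[], byte[]> as used by the scalar quantizers in this PR.
     */
    public static <T, R> QuantizationState trainAndQuantize(
        Quantizer<T, R> quantizer,
        TrainingRequest<T> trainingRequest,
        T vector,
        QuantizationOutput<R> output
    ) {
        QuantizationState state = quantizer.train(trainingRequest);
        quantizer.quantize(vector, state, output);
        return state;
    }
}
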
diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java
new file mode 100644
index 000000000..16f969973
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+import lombok.experimental.UtilityClass;
+
+/**
+ * Utility class providing common computations for quantizer implementations, such as calculating
+ * per-dimension mean thresholds and standard deviations from sampled training vectors.
+ */
+@UtilityClass
+class QuantizerHelper {
+ /**
+ * Calculates the mean vector from a set of sampled vectors.
+ *
+ * @param samplingRequest The {@link TrainingRequest} containing the dataset and methods to access vectors by their indices.
+ * @param sampledIndices An array of indices representing the sampled vectors to be used for mean calculation.
+ * @return A float array representing the mean vector of the sampled vectors.
+ * @throws IllegalArgumentException If any of the vectors at the sampled indices are null.
+ * @throws IllegalStateException If the mean array is unexpectedly null after processing the vectors.
+ */
+ static float[] calculateMeanThresholds(TrainingRequest<float[]> samplingRequest, int[] sampledIndices) {
+ int totalSamples = sampledIndices.length;
+ float[] mean = null;
+ for (int docId : sampledIndices) {
+ float[] vector = samplingRequest.getVectorByDocId(docId);
+ if (vector == null) {
+ throw new IllegalArgumentException("Vector at sampled index " + docId + " is null.");
+ }
+ if (mean == null) {
+ mean = new float[vector.length];
+ }
+ for (int j = 0; j < vector.length; j++) {
+ mean[j] += vector[j];
+ }
+ }
+ if (mean == null) {
+ throw new IllegalStateException("Mean array should not be null after processing vectors.");
+ }
+ for (int j = 0; j < mean.length; j++) {
+ mean[j] /= totalSamples;
+ }
+ return mean;
+ }
+
+ /**
+ * Calculates the mean and StdDev per dimension for sampled vectors.
+ *
+ * @param trainingRequest the request containing the data and parameters for training.
+ * @param sampledIndices the indices of the sampled vectors.
+ * @param meanArray the array to store the sum and then the mean of each dimension.
+ * @param stdDevArray the array to store the sum of squares and then the standard deviation of each dimension.
+ */
+ static void calculateMeanAndStdDev(
+ TrainingRequest<float[]> trainingRequest,
+ int[] sampledIndices,
+ float[] meanArray,
+ float[] stdDevArray
+ ) {
+ int totalSamples = sampledIndices.length;
+ int dimension = meanArray.length;
+ for (int docId : sampledIndices) {
+ float[] vector = trainingRequest.getVectorByDocId(docId);
+ if (vector == null) {
+ throw new IllegalArgumentException("Vector at sampled index " + docId + " is null.");
+ }
+ for (int j = 0; j < dimension; j++) {
+ meanArray[j] += vector[j];
+ stdDevArray[j] += vector[j] * vector[j];
+ }
+ }
+
+ // Calculate mean and standard deviation in one pass
+ for (int j = 0; j < dimension; j++) {
+ meanArray[j] = meanArray[j] / totalSamples;
+ stdDevArray[j] = (float) Math.sqrt((stdDevArray[j] / totalSamples) - (meanArray[j] * meanArray[j]));
+ }
+ }
+}
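
As a sanity check on the one-pass statistics above, the following self-contained snippet reproduces the same arithmetic (population standard deviation via E[x^2] - mean^2) on a toy dataset. It does not call the package-private helper, and the class name is made up for illustration:

import java.util.Arrays;

public class OnePassStatsDemo {
    public static void main(String[] args) {
        float[][] vectors = { { 1f, 2f }, { 3f, 4f }, { 5f, 6f } };
        int dimension = vectors[0].length;
        float[] mean = new float[dimension];
        float[] stdDev = new float[dimension];

        // Accumulate per-dimension sums and sums of squares in a single pass.
        for (float[] vector : vectors) {
            for (int j = 0; j < dimension; j++) {
                mean[j] += vector[j];
                stdDev[j] += vector[j] * vector[j];
            }
        }
        for (int j = 0; j < dimension; j++) {
            mean[j] /= vectors.length;
            stdDev[j] = (float) Math.sqrt(stdDev[j] / vectors.length - mean[j] * mean[j]);
        }

        System.out.println(Arrays.toString(mean));   // [3.0, 4.0]
        System.out.println(Arrays.toString(stdDev)); // both ~1.633, i.e. sqrt(8/3)
    }
}
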
diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java
new file mode 100644
index 000000000..020efe54f
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java
@@ -0,0 +1,90 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+import lombok.NoArgsConstructor;
+
+import java.util.Arrays;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.stream.IntStream;
+
+/**
+ * ReservoirSampler implements the Sampler interface and provides a method for sampling
+ * a specified number of indices from a total number of vectors using the reservoir sampling algorithm.
+ * This algorithm is particularly useful for randomly sampling a subset of data from a larger set
+ * when the total size of the dataset is unknown or very large.
+ */
+@NoArgsConstructor
+final class ReservoirSampler implements Sampler {
+ /**
+ * Singleton instance holder.
+ */
+ private static ReservoirSampler instance;
+
+ /**
+ * Provides the singleton instance of ReservoirSampler.
+ *
+ * @return the singleton instance of ReservoirSampler.
+ */
+ public static synchronized ReservoirSampler getInstance() {
+ if (instance == null) {
+ instance = new ReservoirSampler();
+ }
+ return instance;
+ }
+
+ /**
+ * Samples indices from the range [0, totalNumberOfVectors).
+ * If the total number of vectors is less than or equal to the sample size, it returns all indices.
+ * Otherwise, it uses the reservoir sampling algorithm to select a random subset.
+ *
+ * @param totalNumberOfVectors the total number of vectors to sample from.
+ * @param sampleSize the number of indices to sample.
+ * @return an array of sampled indices.
+ */
+ @Override
+ public int[] sample(final int totalNumberOfVectors, final int sampleSize) {
+ if (totalNumberOfVectors <= sampleSize) {
+ return IntStream.range(0, totalNumberOfVectors).toArray();
+ }
+ return reservoirSampleIndices(totalNumberOfVectors, sampleSize);
+ }
+
+ /**
+ * Applies the reservoir sampling algorithm to select a random sample of indices.
+ * This method ensures that each index in the range [0, numVectors) has an equal probability
+ * of being included in the sample.
+ *
+ * Reservoir sampling is particularly useful for selecting a random sample from a large or unknown-sized dataset.
+ * For more information on the algorithm, see the following link:
+ * <a href="https://en.wikipedia.org/wiki/Reservoir_sampling">Reservoir Sampling - Wikipedia</a>
+ *
+ * @param numVectors the total number of vectors.
+ * @param sampleSize the number of indices to sample.
+ * @return an array of sampled indices.
+ */
+ private int[] reservoirSampleIndices(final int numVectors, final int sampleSize) {
+ int[] indices = new int[sampleSize];
+
+ // Initialize the reservoir with the first sampleSize elements
+ for (int i = 0; i < sampleSize; i++) {
+ indices[i] = i;
+ }
+
+ // Replace elements with gradually decreasing probability
+ for (int i = sampleSize; i < numVectors; i++) {
+ int j = ThreadLocalRandom.current().nextInt(i + 1);
+ if (j < sampleSize) {
+ indices[j] = i;
+ }
+ }
+
+ // Sort the sampled indices
+ Arrays.sort(indices);
+
+ return indices;
+ }
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java
new file mode 100644
index 000000000..5ff385972
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java
@@ -0,0 +1,25 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+/**
+ * The Sampler interface defines the contract for sampling strategies
+ * used in various quantization processes. Implementations of this
+ * interface should provide specific strategies for selecting a sample
+ * from a given set of vectors.
+ */
+public interface Sampler {
+
+ /**
+ * Samples a subset of indices from the total number of vectors.
+ *
+ * @param totalNumberOfVectors the total number of vectors available.
+ * @param sampleSize the number of vectors to be sampled.
+ * @return an array of integers representing the indices of the sampled vectors.
+ * @throws IllegalArgumentException if the sample size is greater than the total number of vectors.
+ */
+ int[] sample(int totalNumberOfVectors, int sampleSize);
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java b/src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java
new file mode 100644
index 000000000..cd9b301df
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java
@@ -0,0 +1,14 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+/**
+ * SamplerType is an enumeration of the different types of samplers that can be created by the factory.
+ */
+public enum SamplerType {
+ RESERVOIR, // Represents a reservoir sampling strategy
+ // Add more enum values here for additional sampler types
+}
diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java
new file mode 100644
index 000000000..80fe5bdae
--- /dev/null
+++ b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
+
+/**
+ * SamplingFactory is a factory class for creating instances of Sampler.
+ * It uses the factory design pattern to encapsulate the creation logic for different types of samplers.
+ */
+@NoArgsConstructor(access = AccessLevel.PRIVATE)
+public final class SamplingFactory {
+
+ /**
+ * Creates and returns a Sampler instance based on the specified SamplerType.
+ *
+ * @param samplerType the type of sampler to create.
+ * @return a Sampler instance.
+ * @throws IllegalArgumentException if the sampler type is not supported.
+ */
+ public static Sampler getSampler(final SamplerType samplerType) {
+ switch (samplerType) {
+ case RESERVOIR:
+ return ReservoirSampler.getInstance();
+ // Add more cases for different samplers here
+ default:
+ throw new IllegalArgumentException("Unsupported sampler type: " + samplerType);
+ }
+ }
+}
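
A quick usage sketch for the sampler factory follows. The counts and class name are arbitrary and chosen only to show that the reservoir sampler returns a sorted array of exactly sampleSize distinct indices when the dataset is larger than the sample, and every index otherwise:

import org.opensearch.knn.quantization.sampler.Sampler;
import org.opensearch.knn.quantization.sampler.SamplerType;
import org.opensearch.knn.quantization.sampler.SamplingFactory;

public class SamplerExample {
    public static void main(String[] args) {
        Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR);

        // 25,000 doc ids sampled uniformly from [0, 1,000,000); the result is sorted ascending.
        int[] sampled = sampler.sample(1_000_000, 25_000);
        System.out.println(sampled.length); // 25000

        // When the dataset is smaller than the sample size, every index is returned.
        int[] all = sampler.sample(5, 25_000);
        System.out.println(all.length); // 5
    }
}
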
diff --git a/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java
new file mode 100644
index 000000000..99621a0e5
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.enums;
+
+import org.opensearch.knn.KNNTestCase;
+
+import java.util.HashSet;
+import java.util.Set;
+
+public class ScalarQuantizationTypeTests extends KNNTestCase {
+ public void testSQTypesValues() {
+ ScalarQuantizationType[] expectedValues = {
+ ScalarQuantizationType.ONE_BIT,
+ ScalarQuantizationType.TWO_BIT,
+ ScalarQuantizationType.FOUR_BIT };
+ assertArrayEquals(expectedValues, ScalarQuantizationType.values());
+ }
+
+ public void testSQTypesValueOf() {
+ assertEquals(ScalarQuantizationType.ONE_BIT, ScalarQuantizationType.valueOf("ONE_BIT"));
+ assertEquals(ScalarQuantizationType.TWO_BIT, ScalarQuantizationType.valueOf("TWO_BIT"));
+ assertEquals(ScalarQuantizationType.FOUR_BIT, ScalarQuantizationType.valueOf("FOUR_BIT"));
+ }
+
+ public void testUniqueSQTypeValues() {
+ Set<Integer> uniqueIds = new HashSet<>();
+ for (ScalarQuantizationType type : ScalarQuantizationType.values()) {
+ boolean added = uniqueIds.add(type.getId());
+ assertTrue("Duplicate value found: " + type.getId(), added);
+ }
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java
new file mode 100644
index 000000000..3474b7ec9
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.factory;
+
+import org.junit.Before;
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
+import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer;
+import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer;
+import org.opensearch.knn.quantization.quantizer.Quantizer;
+
+import java.lang.reflect.Field;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+public class QuantizerFactoryTests extends KNNTestCase {
+
+ @Before
+ public void resetIsRegisteredFlag() throws NoSuchFieldException, IllegalAccessException {
+ Field isRegisteredField = QuantizerFactory.class.getDeclaredField("isRegistered");
+ isRegisteredField.setAccessible(true);
+ AtomicBoolean isRegistered = (AtomicBoolean) isRegisteredField.get(null);
+ isRegistered.set(false);
+ }
+
+ public void test_Lazy_Registration() {
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+ ScalarQuantizationParams paramsTwoBit = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
+ ScalarQuantizationParams paramsFourBit = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
+ assertFalse(isRegisteredFieldAccessible());
+ Quantizer<?, ?> quantizer = QuantizerFactory.getQuantizer(params);
+ Quantizer<?, ?> quantizerTwoBit = QuantizerFactory.getQuantizer(paramsTwoBit);
+ Quantizer<?, ?> quantizerFourBit = QuantizerFactory.getQuantizer(paramsFourBit);
+ assertTrue(quantizerFourBit instanceof MultiBitScalarQuantizer);
+ assertTrue(quantizerTwoBit instanceof MultiBitScalarQuantizer);
+ assertTrue(quantizer instanceof OneBitScalarQuantizer);
+ assertTrue(isRegisteredFieldAccessible());
+ }
+
+ public void testGetQuantizer_withNullParams() {
+ try {
+ QuantizerFactory.getQuantizer(null);
+ fail("Expected IllegalArgumentException");
+ } catch (IllegalArgumentException e) {
+ assertEquals("Quantization parameters must not be null.", e.getMessage());
+ }
+ }
+
+ private boolean isRegisteredFieldAccessible() {
+ try {
+ Field isRegisteredField = QuantizerFactory.class.getDeclaredField("isRegistered");
+ isRegisteredField.setAccessible(true);
+ AtomicBoolean isRegistered = (AtomicBoolean) isRegisteredField.get(null);
+ return isRegistered.get();
+ } catch (NoSuchFieldException | IllegalAccessException e) {
+ fail("Failed to access isRegistered field.");
+ return false;
+ }
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java
new file mode 100644
index 000000000..dec34e632
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.factory;
+
+import org.junit.BeforeClass;
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
+import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer;
+import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer;
+import org.opensearch.knn.quantization.quantizer.Quantizer;
+
+public class QuantizerRegistryTests extends KNNTestCase {
+
+ @BeforeClass
+ public static void setup() {
+ QuantizerRegistry.register(
+ ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT),
+ new OneBitScalarQuantizer()
+ );
+ QuantizerRegistry.register(
+ ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.TWO_BIT),
+ new MultiBitScalarQuantizer(2)
+ );
+ QuantizerRegistry.register(
+ ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.FOUR_BIT),
+ new MultiBitScalarQuantizer(4)
+ );
+ }
+
+ public void testRegisterAndGetQuantizer() {
+ // Test for OneBitScalarQuantizer
+ ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+ Quantizer<?, ?> oneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams);
+ assertTrue(oneBitQuantizer instanceof OneBitScalarQuantizer);
+
+ // Test for MultiBitScalarQuantizer (2-bit)
+ ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
+ Quantizer<?, ?> twoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams);
+ assertTrue(twoBitQuantizer instanceof MultiBitScalarQuantizer);
+ assertEquals(2, ((MultiBitScalarQuantizer) twoBitQuantizer).getBitsPerCoordinate());
+
+ // Test for MultiBitScalarQuantizer (4-bit)
+ ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
+ Quantizer<?, ?> fourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams);
+ assertTrue(fourBitQuantizer instanceof MultiBitScalarQuantizer);
+ assertEquals(4, ((MultiBitScalarQuantizer) fourBitQuantizer).getBitsPerCoordinate());
+ }
+
+ public void testQuantizerRegistryIsSingleton() {
+ // Ensure the same instance is returned for the same type identifier
+ ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+ Quantizer<?, ?> firstOneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams);
+ Quantizer<?, ?> secondOneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams);
+ assertSame(firstOneBitQuantizer, secondOneBitQuantizer);
+
+ // Ensure the same instance is returned for the same type identifier (2-bit)
+ ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
+ Quantizer<?, ?> firstTwoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams);
+ Quantizer<?, ?> secondTwoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams);
+ assertSame(firstTwoBitQuantizer, secondTwoBitQuantizer);
+
+ // Ensure the same instance is returned for the same type identifier (4-bit)
+ ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
+ Quantizer<?, ?> firstFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams);
+ Quantizer<?, ?> secondFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams);
+ assertSame(firstFourBitQuantizer, secondFourBitQuantizer);
+ }
+
+ public void testRegisterQuantizerThrowsExceptionWhenAlreadyRegistered() {
+ ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+
+ // Attempt to register the same quantizer again should throw an exception
+ assertThrows(IllegalArgumentException.class, () -> {
+ QuantizerRegistry.register(
+ ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT),
+ new OneBitScalarQuantizer()
+ );
+ });
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java
new file mode 100644
index 000000000..fa25e8e80
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizationState;
+
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
+import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
+
+import java.io.IOException;
+
+public class QuantizationStateSerializerTests extends KNNTestCase {
+
+ public void testSerializeAndDeserializeOneBitScalarQuantizationState() throws IOException {
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+ float[] mean = new float[] { 0.1f, 0.2f, 0.3f };
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean);
+
+ // Serialize
+ byte[] serialized = state.toByteArray();
+
+ OneBitScalarQuantizationState deserialized = OneBitScalarQuantizationState.fromByteArray(serialized);
+
+ assertArrayEquals(mean, deserialized.getMeanThresholds(), 0.0f);
+ assertEquals(params, deserialized.getQuantizationParams());
+ }
+
+ public void testSerializeAndDeserializeMultiBitScalarQuantizationState() throws IOException {
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
+ float[][] thresholds = new float[][] { { 0.1f, 0.2f, 0.3f }, { 0.4f, 0.5f, 0.6f } };
+ MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+
+ // Serialize
+ byte[] serialized = state.toByteArray();
+ MultiBitScalarQuantizationState deserialized = MultiBitScalarQuantizationState.fromByteArray(serialized);
+
+ for (int i = 0; i < thresholds.length; i++) {
+ assertArrayEquals(thresholds[i], deserialized.getThresholds()[i], 0.0f);
+ }
+ assertEquals(params, deserialized.getQuantizationParams());
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java
new file mode 100644
index 000000000..35edf49e2
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizationState;
+
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
+import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
+
+import java.io.IOException;
+
+public class QuantizationStateTests extends KNNTestCase {
+
+ public void testOneBitScalarQuantizationStateSerialization() throws IOException {
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+ float[] mean = { 1.0f, 2.0f, 3.0f };
+
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean);
+
+ // Serialize
+ byte[] serializedState = state.toByteArray();
+
+ // Deserialize
+ StreamInput in = StreamInput.wrap(serializedState);
+ OneBitScalarQuantizationState deserializedState = new OneBitScalarQuantizationState(in);
+
+ float delta = 0.0001f;
+ assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta);
+ assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType());
+ }
+
+ public void testMultiBitScalarQuantizationStateSerialization() throws IOException {
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
+ float[][] thresholds = { { 0.5f, 1.5f, 2.5f }, { 1.0f, 2.0f, 3.0f } };
+
+ MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+ byte[] serializedState = state.toByteArray();
+
+ // Deserialize
+ StreamInput in = StreamInput.wrap(serializedState);
+ MultiBitScalarQuantizationState deserializedState = new MultiBitScalarQuantizationState(in);
+
+ float delta = 0.0001f;
+ for (int i = 0; i < thresholds.length; i++) {
+ assertArrayEquals(thresholds[i], deserializedState.getThresholds()[i], delta);
+ }
+ assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType());
+ }
+
+ public void testSerializationWithDifferentVersions() throws IOException {
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+ float[] mean = { 1.0f, 2.0f, 3.0f };
+
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean);
+ byte[] serializedState = state.toByteArray();
+ StreamInput in = StreamInput.wrap(serializedState);
+ OneBitScalarQuantizationState deserializedState = new OneBitScalarQuantizationState(in);
+
+ float delta = 0.0001f;
+ assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta);
+ assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType());
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java
new file mode 100644
index 000000000..ad6a44686
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java
@@ -0,0 +1,107 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
+import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+
+import java.io.IOException;
+
+public class MultiBitScalarQuantizerTests extends KNNTestCase {
+
+ public void testTrain_twoBit() {
+ float[][] vectors = {
+ { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f },
+ { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f },
+ { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } };
+ MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2);
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
+ TrainingRequest<float[]> request = new MockTrainingRequest(params, vectors);
+ QuantizationState state = twoBitQuantizer.train(request);
+
+ assertTrue(state instanceof MultiBitScalarQuantizationState);
+ MultiBitScalarQuantizationState mbState = (MultiBitScalarQuantizationState) state;
+ assertNotNull(mbState.getThresholds());
+ assertEquals(2, mbState.getThresholds().length); // 2-bit quantization should have 2 thresholds
+ }
+
+ public void testTrain_fourBit() {
+ MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4);
+ float[][] vectors = {
+ { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f },
+ { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f },
+ { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } };
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
+ TrainingRequest<float[]> request = new MockTrainingRequest(params, vectors);
+ QuantizationState state = fourBitQuantizer.train(request);
+
+ assertTrue(state instanceof MultiBitScalarQuantizationState);
+ MultiBitScalarQuantizationState mbState = (MultiBitScalarQuantizationState) state;
+ assertNotNull(mbState.getThresholds());
+ assertEquals(4, mbState.getThresholds().length); // 4-bit quantization should have 4 thresholds
+ }
+
+ public void testQuantize_twoBit() throws IOException {
+ MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2);
+ float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f };
+ float[][] thresholds = { { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }, { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } };
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
+ MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+
+ BinaryQuantizationOutput output = new BinaryQuantizationOutput();
+ twoBitQuantizer.quantize(vector, state, output);
+ assertNotNull(output.getQuantizedVector());
+ }
+
+ public void testQuantize_fourBit() throws IOException {
+ MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4);
+ float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f };
+ float[][] thresholds = {
+ { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f },
+ { 1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f, 7.1f, 8.1f },
+ { 1.2f, 2.2f, 3.2f, 4.2f, 5.2f, 6.2f, 7.2f, 8.2f },
+ { 1.3f, 2.3f, 3.3f, 4.3f, 5.3f, 6.3f, 7.3f, 8.3f } };
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
+ MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+
+ BinaryQuantizationOutput output = new BinaryQuantizationOutput();
+ fourBitQuantizer.quantize(vector, state, output);
+ assertNotNull(output.getQuantizedVector());
+ }
+
+ public void testQuantize_withNullVector() throws IOException {
+ MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2);
+ BinaryQuantizationOutput output = new BinaryQuantizationOutput();
+ expectThrows(
+ IllegalArgumentException.class,
+ () -> twoBitQuantizer.quantize(
+ null,
+ new MultiBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT), new float[2][8]),
+ output
+ )
+ );
+ }
+
+ // Mock classes for testing
+ private static class MockTrainingRequest extends TrainingRequest<float[]> {
+ private final float[][] vectors;
+
+ public MockTrainingRequest(ScalarQuantizationParams params, float[][] vectors) {
+ super(vectors.length);
+ this.vectors = vectors;
+ }
+
+ @Override
+ public float[] getVectorByDocId(int docId) {
+ return vectors[docId];
+ }
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java
new file mode 100644
index 000000000..28be260d7
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java
@@ -0,0 +1,136 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.quantizer;
+
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.knn.KNNTestCase;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
+import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+import org.opensearch.knn.quantization.sampler.Sampler;
+import org.opensearch.knn.quantization.sampler.SamplerType;
+import org.opensearch.knn.quantization.sampler.SamplingFactory;
+
+import java.io.IOException;
+
+public class OneBitScalarQuantizerTests extends KNNTestCase {
+
+ public void testTrain_withTrainingRequired() {
+ float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } };
+
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+ TrainingRequest<float[]> originalRequest = new TrainingRequest<float[]>(vectors.length) {
+ @Override
+ public float[] getVectorByDocId(int docId) {
+ return vectors[docId];
+ }
+ };
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ QuantizationState state = quantizer.train(originalRequest);
+
+ assertTrue(state instanceof OneBitScalarQuantizationState);
+ float[] meanThresholds = ((OneBitScalarQuantizationState) state).getMeanThresholds();
+ assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, meanThresholds, 0.001f);
+ }
+
+ public void testQuantize_withState() throws IOException {
+ float[] vector = { 3.0f, 6.0f, 9.0f };
+ float[] thresholds = { 4.0f, 5.0f, 6.0f };
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(
+ new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT),
+ thresholds
+ );
+
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ BinaryQuantizationOutput output = new BinaryQuantizationOutput();
+ quantizer.quantize(vector, state, output);
+
+ assertNotNull(output);
+ byte[] expectedPackedBits = new byte[] { 0b01100000 }; // or 96 in decimal
+ assertArrayEquals(expectedPackedBits, output.getQuantizedVector());
+ }
+
+ public void testQuantize_withNullVector() throws IOException {
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(
+ new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT),
+ new float[] { 0.0f }
+ );
+ BinaryQuantizationOutput output = new BinaryQuantizationOutput();
+ expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(null, state, output));
+ }
+
+ public void testQuantize_withInvalidState() throws IOException {
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ float[] vector = { 1.0f, 2.0f, 3.0f };
+ QuantizationState invalidState = new QuantizationState() {
+ @Override
+ public ScalarQuantizationParams getQuantizationParams() {
+ return new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+ }
+
+ @Override
+ public byte[] toByteArray() {
+ return new byte[0];
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ // Empty implementation for test
+ }
+ };
+ BinaryQuantizationOutput output = new BinaryQuantizationOutput();
+ expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, invalidState, output));
+ }
+
+ public void testQuantize_withMismatchedDimensions() throws IOException {
+ OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer();
+ float[] vector = { 1.0f, 2.0f, 3.0f };
+ float[] thresholds = { 4.0f, 5.0f };
+ OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(
+ new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT),
+ thresholds
+ );
+ BinaryQuantizationOutput output = new BinaryQuantizationOutput();
+ expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, state, output));
+ }
+
+ public void testCalculateMean() {
+ float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } };
+
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+ TrainingRequest<float[]> samplingRequest = new TrainingRequest<float[]>(vectors.length) {
+ @Override
+ public float[] getVectorByDocId(int docId) {
+ return vectors[docId];
+ }
+ };
+
+ Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR);
+ int[] sampledIndices = sampler.sample(vectors.length, 3);
+ float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(samplingRequest, sampledIndices);
+ assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, meanThresholds, 0.001f);
+ }
+
+ public void testCalculateMean_withNullVector() {
+ float[][] vectors = { { 1.0f, 2.0f, 3.0f }, null, { 7.0f, 8.0f, 9.0f } };
+
+ ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
+ TrainingRequest<float[]> samplingRequest = new TrainingRequest<float[]>(vectors.length) {
+ @Override
+ public float[] getVectorByDocId(int docId) {
+ return vectors[docId];
+ }
+ };
+
+ Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR);
+ int[] sampledIndices = sampler.sample(vectors.length, 3);
+ expectThrows(IllegalArgumentException.class, () -> QuantizerHelper.calculateMeanThresholds(samplingRequest, sampledIndices));
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java
new file mode 100644
index 000000000..59952eb10
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+import org.opensearch.knn.KNNTestCase;
+
+import java.util.Arrays;
+import java.util.stream.IntStream;
+
+public class ReservoirSamplerTests extends KNNTestCase {
+
+ public void testSampleLessThanSampleSize() {
+ ReservoirSampler sampler = ReservoirSampler.getInstance();
+ int totalNumberOfVectors = 5;
+ int sampleSize = 10;
+ int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+ int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray();
+ assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices);
+ }
+
+ public void testSampleEqualToSampleSize() {
+ ReservoirSampler sampler = ReservoirSampler.getInstance();
+ int totalNumberOfVectors = 10;
+ int sampleSize = 10;
+ int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+ int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray();
+ assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices);
+ }
+
+ public void testSampleRandomness() {
+ ReservoirSampler sampler1 = ReservoirSampler.getInstance();
+ ReservoirSampler sampler2 = ReservoirSampler.getInstance();
+ int totalNumberOfVectors = 100;
+ int sampleSize = 10;
+
+ int[] sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize);
+ int[] sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize);
+
+ // Two independent samples of 10 out of 100 are overwhelmingly unlikely to be identical,
+ // so asserting inequality is a practical (if not strictly deterministic) randomness check.
+ Arrays.sort(sampledIndices1);
+ Arrays.sort(sampledIndices2);
+ assertFalse("Sampled indices should be different", Arrays.equals(sampledIndices1, sampledIndices2));
+ }
+
+ public void testEdgeCaseZeroVectors() {
+ ReservoirSampler sampler = ReservoirSampler.getInstance();
+ int totalNumberOfVectors = 0;
+ int sampleSize = 10;
+ int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+ assertEquals("Sampled indices should be empty when there are zero vectors.", 0, sampledIndices.length);
+ }
+
+ public void testEdgeCaseZeroSampleSize() {
+ ReservoirSampler sampler = ReservoirSampler.getInstance();
+ int totalNumberOfVectors = 10;
+ int sampleSize = 0;
+ int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize);
+ assertEquals("Sampled indices should be empty when sample size is zero.", 0, sampledIndices.length);
+ }
+}
diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java
new file mode 100644
index 000000000..db8772b70
--- /dev/null
+++ b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java
@@ -0,0 +1,19 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.knn.quantization.sampler;
+
+import org.opensearch.knn.KNNTestCase;
+
+public class SamplingFactoryTests extends KNNTestCase {
+ public void testGetSampler_withReservoir() {
+ Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR);
+ assertTrue(sampler instanceof ReservoirSampler);
+ }
+
+ public void testGetSampler_withUnsupportedType() {
+ expectThrows(NullPointerException.class, () -> SamplingFactory.getSampler(null)); // This should throw an exception
+ }
+}