From 8eb3e19494f5aeac4aa9cbb5f38c7f83d582b830 Mon Sep 17 00:00:00 2001 From: Vikasht34 Date: Tue, 13 Aug 2024 12:45:53 -0700 Subject: [PATCH] Quantization Framework Implementation with 1bit and MultiBit Binary Quantizer (#1929) * Quantization Framework Implementation with 1bit and MultiBit Binary Quantizer Signed-off-by: VIKASH TIWARI * Quantization Framework Implementation with 1bit and MultiBit Binary Quantizer Signed-off-by: VIKASH TIWARI * Implemented Serlization using Writable Signed-off-by: VIKASH TIWARI --------- Signed-off-by: VIKASH TIWARI Signed-off-by: Vikasht34 --- CHANGELOG.md | 1 + .../enums/ScalarQuantizationType.java | 62 ++++++ .../factory/QuantizerFactory.java | 54 +++++ .../factory/QuantizerRegistrar.java | 46 +++++ .../factory/QuantizerRegistry.java | 59 ++++++ .../BinaryQuantizationOutput.java | 67 +++++++ .../QuantizationOutput.java | 28 +++ .../QuantizationParams.java | 27 +++ .../ScalarQuantizationParams.java | 77 ++++++++ .../DefaultQuantizationState.java | 67 +++++++ .../MultiBitScalarQuantizationState.java | 127 ++++++++++++ .../OneBitScalarQuantizationState.java | 110 +++++++++++ .../quantizationState/QuantizationState.java | 32 +++ .../QuantizationStateSerializer.java | 56 ++++++ .../models/requests/TrainingRequest.java | 31 +++ .../knn/quantization/quantizer/BitPacker.java | 143 ++++++++++++++ .../quantizer/MultiBitScalarQuantizer.java | 186 ++++++++++++++++++ .../quantizer/OneBitScalarQuantizer.java | 100 ++++++++++ .../knn/quantization/quantizer/Quantizer.java | 40 ++++ .../quantizer/QuantizerHelper.java | 84 ++++++++ .../sampler/ReservoirSampler.java | 90 +++++++++ .../knn/quantization/sampler/Sampler.java | 25 +++ .../knn/quantization/sampler/SamplerType.java | 14 ++ .../quantization/sampler/SamplingFactory.java | 34 ++++ .../enums/ScalarQuantizationTypeTests.java | 35 ++++ .../factory/QuantizerFactoryTests.java | 63 ++++++ .../factory/QuantizerRegistryTests.java | 84 ++++++++ .../QuantizationStateSerializerTests.java | 46 +++++ .../QuantizationStateTests.java | 68 +++++++ .../MultiBitScalarQuantizerTests.java | 107 ++++++++++ .../quantizer/OneBitScalarQuantizerTests.java | 136 +++++++++++++ .../sampler/ReservoirSamplerTests.java | 63 ++++++ .../sampler/SamplingFactoryTests.java | 19 ++ 33 files changed, 2181 insertions(+) create mode 100644 src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java create mode 100644 src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java create mode 100644 src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java create mode 100644 src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java create mode 100644 src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java create mode 100644 src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java create mode 100644 src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java create mode 100644 src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java create mode 100644 src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java create mode 100644 src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java create mode 100644 src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java create mode 100644 src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java create mode 100644 src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java create mode 100644 src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java create mode 100644 src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index eb8427b1f..dfca7bece 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,3 +35,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925) * Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877) * Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939) +* Added Quantization Framework and implemented 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889) diff --git a/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java new file mode 100644 index 000000000..40347ad93 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.enums; + +import lombok.Getter; + +/** + * The ScalarQuantizationType enum defines the various scalar quantization types that can be used + * for vector quantization. Each type corresponds to a different bit-width representation of the quantized values. + * + *

+ * Future Developers: If you change the name of any enum constant, do not change its associated value. + * Serialization and deserialization depend on these values to maintain compatibility. + *

+ */ +@Getter +public enum ScalarQuantizationType { + /** + * ONE_BIT quantization uses a single bit per coordinate. + */ + ONE_BIT(1), + + /** + * TWO_BIT quantization uses two bits per coordinate. + */ + TWO_BIT(2), + + /** + * FOUR_BIT quantization uses four bits per coordinate. + */ + FOUR_BIT(4); + + private final int id; + + /** + * Constructs a ScalarQuantizationType with the specified ID. + * + * @param id the ID representing the quantization type. + */ + ScalarQuantizationType(int id) { + this.id = id; + } + + /** + * Returns the ScalarQuantizationType associated with the given ID. + * + * @param id the ID of the quantization type. + * @return the corresponding ScalarQuantizationType. + * @throws IllegalArgumentException if the ID does not correspond to any ScalarQuantizationType. + */ + public static ScalarQuantizationType fromId(int id) { + for (ScalarQuantizationType type : ScalarQuantizationType.values()) { + if (type.getId() == id) { + return type; + } + } + throw new IllegalArgumentException("Unknown ScalarQuantizationType ID: " + id); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java new file mode 100644 index 000000000..b99f6ebdc --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.factory; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.quantizer.Quantizer; + +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * The QuantizerFactory class is responsible for creating instances of {@link Quantizer} + * based on the provided {@link QuantizationParams}. It uses a registry to look up the + * appropriate quantizer implementation for the given quantization parameters. + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public final class QuantizerFactory { + private static final AtomicBoolean isRegistered = new AtomicBoolean(false); + + /** + * Ensures that default quantizers are registered. + */ + private static void ensureRegistered() { + if (!isRegistered.get()) { + synchronized (QuantizerFactory.class) { + if (!isRegistered.get()) { + QuantizerRegistrar.registerDefaultQuantizers(); + isRegistered.set(true); + } + } + } + } + + /** + * Retrieves a quantizer instance based on the provided quantization parameters. + * + * @param params the quantization parameters used to determine the appropriate quantizer + * @param

the type of quantization parameters, extending {@link QuantizationParams} + * @param the type of the quantized output + * @return an instance of {@link Quantizer} corresponding to the provided parameters + */ + public static

Quantizer getQuantizer(final P params) { + if (params == null) { + throw new IllegalArgumentException("Quantization parameters must not be null."); + } + // Lazy Registration instead of static block as class level; + ensureRegistered(); + return QuantizerRegistry.getQuantizer(params); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java new file mode 100644 index 000000000..7b542aea0 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.factory; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; + +/** + * The QuantizerRegistrar class is responsible for registering default quantizers. + * This class ensures that the registration happens only once in a thread-safe manner. + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +final class QuantizerRegistrar { + + /** + * Registers default quantizers + *

+ * This method is synchronized to ensure that registration occurs only once, + * even in a multi-threaded environment. + *

+ */ + static synchronized void registerDefaultQuantizers() { + // Register OneBitScalarQuantizer for SQParams with VALUE_QUANTIZATION and SQTypes.ONE_BIT + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT), + new OneBitScalarQuantizer() + ); + // Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION with bit per co-ordinate = 2 + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.TWO_BIT), + new MultiBitScalarQuantizer(2) + ); + // Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION with bit per co-ordinate = 4 + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.FOUR_BIT), + new MultiBitScalarQuantizer(4) + ); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java new file mode 100644 index 000000000..ac266f547 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java @@ -0,0 +1,59 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.factory; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.quantizer.Quantizer; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * The QuantizerRegistry class is responsible for managing the registration and retrieval + * of quantizer instances. Quantizers are registered with specific quantization parameters + * and type identifiers, allowing for efficient lookup and instantiation. + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +final class QuantizerRegistry { + // ConcurrentHashMap for thread-safe access + private static final Map> registry = new ConcurrentHashMap<>(); + + /** + * Registers a quantizer with the registry. + * + * @param paramIdentifier the unique identifier for the quantization parameters + * @param quantizer an instance of the quantizer + */ + static void register(final String paramIdentifier, final Quantizer quantizer) { + // Check if the quantizer is already registered for the given identifier + if (registry.putIfAbsent(paramIdentifier, quantizer) != null) { + // Throw an exception if a quantizer is already registered + throw new IllegalArgumentException("Quantizer already registered for identifier: " + paramIdentifier); + } + } + + /** + * Retrieves a quantizer instance based on the provided quantization parameters. + * + * @param params the quantization parameters used to determine the appropriate quantizer + * @param

the type of quantization parameters + * @param the type of the quantized output + * @return an instance of {@link Quantizer} corresponding to the provided parameters + * @throws IllegalArgumentException if no quantizer is registered for the given parameters + */ + static

Quantizer getQuantizer(final P params) { + String identifier = params.getTypeIdentifier(); + Quantizer quantizer = registry.get(identifier); + if (quantizer == null) { + throw new IllegalArgumentException("No quantizer registered for type identifier: " + identifier); + } + @SuppressWarnings("unchecked") + Quantizer typedQuantizer = (Quantizer) quantizer; + return typedQuantizer; + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java new file mode 100644 index 000000000..95592fcb9 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationOutput; + +import lombok.Getter; +import lombok.NoArgsConstructor; + +import java.util.Arrays; + +/** + * The BinaryQuantizationOutput class represents the output of a quantization process in binary format. + * It implements the QuantizationOutput interface to handle byte arrays specifically. + */ +@NoArgsConstructor +public class BinaryQuantizationOutput implements QuantizationOutput { + @Getter + private byte[] quantizedVector; + + /** + * Prepares the quantized vector array based on the provided parameters and returns it for direct modification. + * This method ensures that the internal byte array is appropriately sized and cleared before being used. + * The method accepts two parameters: + *

    + *
  • bitsPerCoordinate: The number of bits used per coordinate. This determines the granularity of the quantization.
  • + *
  • vectorLength: The length of the original vector that needs to be quantized. This helps in calculating the required byte array size.
  • + *
+ * If the existing quantized vector is either null or not the same size as the required byte array, + * a new byte array is allocated. Otherwise, the existing array is cleared (i.e., all bytes are set to zero). + * This method is designed to be used in conjunction with a bit-packing utility that writes quantized values directly + * into the returned byte array. + * @param params an array of parameters, where the first parameter is the number of bits per coordinate (int), + * and the second parameter is the length of the vector (int). + * @return the prepared and writable quantized vector as a byte array. + * @throws IllegalArgumentException if the parameters are not as expected (e.g., missing or not integers). + */ + @Override + public byte[] prepareAndGetWritableQuantizedVector(Object... params) { + if (params.length != 2 || !(params[0] instanceof Integer) || !(params[1] instanceof Integer)) { + throw new IllegalArgumentException("Expected two integer parameters: bitsPerCoordinate and vectorLength"); + } + int bitsPerCoordinate = (int) params[0]; + int vectorLength = (int) params[1]; + int totalBits = bitsPerCoordinate * vectorLength; + int byteLength = (totalBits + 7) >> 3; + + if (this.quantizedVector == null || this.quantizedVector.length != byteLength) { + this.quantizedVector = new byte[byteLength]; + } else { + Arrays.fill(this.quantizedVector, (byte) 0); + } + + return this.quantizedVector; + } + + /** + * Returns the quantized vector. + * + * @return the quantized vector byte array. + */ + @Override + public byte[] getQuantizedVector() { + return quantizedVector; + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java new file mode 100644 index 000000000..aa81a8821 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationOutput; + +/** + * The QuantizationOutput interface defines the contract for quantization output data. + * + * @param The type of the quantized data. + */ +public interface QuantizationOutput { + /** + * Returns the quantized vector. + * + * @return the quantized data. + */ + T getQuantizedVector(); + + /** + * Prepares and returns the writable quantized vector for direct modification. + * + * @param params the parameters needed for preparing the quantized vector. + * @return the prepared and writable quantized vector. + */ + T prepareAndGetWritableQuantizedVector(Object... params); +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java new file mode 100644 index 000000000..4f2ee36c5 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationParams; + +import org.opensearch.core.common.io.stream.Writeable; + +/** + * Interface for quantization parameters. + * This interface defines the basic contract for all quantization parameter types. + * It provides methods to retrieve the quantization type and a unique type identifier. + * Implementations of this interface are expected to provide specific configurations + * for various quantization strategies. + */ +public interface QuantizationParams extends Writeable { + /** + * Provides a unique identifier for the quantization parameters. + * This identifier is typically a combination of the quantization type + * and additional specifics, and it serves to distinguish between different + * configurations or modes of quantization. + * + * @return a string representing the unique type identifier. + */ + String getTypeIdentifier(); +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java new file mode 100644 index 000000000..4e7a53892 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationParams; + +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; + +import java.io.IOException; + +/** + * The ScalarQuantizationParams class represents the parameters specific to scalar quantization (SQ). + * This class implements the QuantizationParams interface and includes the type of scalar quantization. + */ +@Getter +@AllArgsConstructor +@NoArgsConstructor // No-argument constructor for deserialization +@EqualsAndHashCode +public class ScalarQuantizationParams implements QuantizationParams { + private ScalarQuantizationType sqType; + + /** + * Static method to generate type identifier based on ScalarQuantizationType. + * + * @param sqType the scalar quantization type. + * @return A string representing the unique type identifier. + */ + public static String generateTypeIdentifier(ScalarQuantizationType sqType) { + return generateIdentifier(sqType.getId()); + } + + /** + * Provides a unique type identifier for the ScalarQuantizationParams, combining the SQ type. + * This identifier is useful for distinguishing between different configurations of scalar quantization parameters. + * + * @return A string representing the unique type identifier. + */ + @Override + public String getTypeIdentifier() { + return generateIdentifier(sqType.getId()); + } + + private static String generateIdentifier(int id) { + return "ScalarQuantizationParams_" + id; + } + + /** + * Writes the object to the output stream. + * This method is part of the Writeable interface and is used to serialize the object. + * + * @param out the output stream to write the object to. + * @throws IOException if an I/O error occurs. + */ + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(sqType.getId()); + } + + /** + * Reads the object from the input stream. + * This method is part of the Writeable interface and is used to deserialize the object. + * + * @param in the input stream to read the object from. + * @throws IOException if an I/O error occurs. + */ + public ScalarQuantizationParams(StreamInput in, int version) throws IOException { + int typeId = in.readVInt(); + this.sqType = ScalarQuantizationType.fromId(typeId); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java new file mode 100644 index 000000000..33e775cad --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.opensearch.Version; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; + +import java.io.IOException; + +/** + * DefaultQuantizationState is used as a fallback state when no training is required or if training fails. + * It can be utilized by any quantizer to represent a default state. + */ +@Getter +@NoArgsConstructor // No-argument constructor for deserialization +@AllArgsConstructor +public class DefaultQuantizationState implements QuantizationState { + private QuantizationParams params; + + @Override + public QuantizationParams getQuantizationParams() { + return params; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(Version.CURRENT.id); // Write the version + params.writeTo(out); + } + + public DefaultQuantizationState(StreamInput in) throws IOException { + int version = in.readInt(); // Read the version + this.params = new ScalarQuantizationParams(in, version); + } + + /** + * Serializes the quantization state to a byte array. + * + * @return a byte array representing the serialized state. + * @throws IOException if an I/O error occurs during serialization. + */ + @Override + public byte[] toByteArray() throws IOException { + return QuantizationStateSerializer.serialize(this); + } + + /** + * Deserializes a DefaultQuantizationState from a byte array. + * + * @param bytes the byte array containing the serialized state. + * @return the deserialized DefaultQuantizationState. + * @throws IOException if an I/O error occurs during deserialization. + * @throws ClassNotFoundException if the class of the serialized object cannot be found. + */ + public static DefaultQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException { + return (DefaultQuantizationState) QuantizationStateSerializer.deserialize(bytes, DefaultQuantizationState::new); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java new file mode 100644 index 000000000..09092fde8 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java @@ -0,0 +1,127 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.opensearch.Version; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; + +import java.io.IOException; + +/** + * MultiBitScalarQuantizationState represents the state of multi-bit scalar quantization, + * including the thresholds used for quantization. + */ +@Getter +@NoArgsConstructor // No-argument constructor for deserialization +@AllArgsConstructor +public final class MultiBitScalarQuantizationState implements QuantizationState { + private ScalarQuantizationParams quantizationParams; + /** + * The threshold values for multi-bit quantization, organized as a 2D array + * where each row corresponds to a different bit level. + * + * For example: + * - For 2-bit quantization: + * thresholds[0] -> {0.5f, 1.5f, 2.5f} // Thresholds for the first bit level + * thresholds[1] -> {1.0f, 2.0f, 3.0f} // Thresholds for the second bit level + * - For 4-bit quantization: + * thresholds[0] -> {0.1f, 0.2f, 0.3f} + * thresholds[1] -> {0.4f, 0.5f, 0.6f} + * thresholds[2] -> {0.7f, 0.8f, 0.9f} + * thresholds[3] -> {1.0f, 1.1f, 1.2f} + * + * Each column represents the threshold for a specific dimension in the vector space. + */ + private float[][] thresholds; + + @Override + public ScalarQuantizationParams getQuantizationParams() { + return quantizationParams; + } + + /** + * This method is responsible for writing the state of the MultiBitScalarQuantizationState object to an external output. + * It includes versioning information to ensure compatibility between different versions of the serialized object. + * + * @param out the StreamOutput to write the object to. + * @throws IOException if an I/O error occurs during serialization. + */ + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(Version.CURRENT.id); // Write the version + quantizationParams.writeTo(out); + out.writeVInt(thresholds.length); // Number of rows + for (float[] row : thresholds) { + out.writeFloatArray(row); // Write each row as a float array + } + } + + /** + * This method is responsible for reading the state of the MultiBitScalarQuantizationState object from an external input. + * It includes versioning information to ensure compatibility between different versions of the serialized object. + * + * @param in the StreamInput to read the object from. + * @throws IOException if an I/O error occurs during deserialization. + */ + public MultiBitScalarQuantizationState(StreamInput in) throws IOException { + int version = in.readVInt(); // Read the version + this.quantizationParams = new ScalarQuantizationParams(in, version); + int rows = in.readVInt(); // Read the number of rows + this.thresholds = new float[rows][]; + for (int i = 0; i < rows; i++) { + this.thresholds[i] = in.readFloatArray(); // Read each row as a float array + } + } + + /** + * Serializes the current state of this MultiBitScalarQuantizationState object into a byte array. + * This method uses the QuantizationStateSerializer to handle the serialization process. + * + *

The serialized byte array includes all necessary state information, such as the thresholds + * and quantization parameters, ensuring that the object can be fully reconstructed from the byte array.

+ * + *
+     * {@code
+     * MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);
+     * byte[] serializedState = state.toByteArray();
+     * }
+     * 
+ * + * @return a byte array representing the serialized state of this object. + * @throws IOException if an I/O error occurs during serialization. + */ + @Override + public byte[] toByteArray() throws IOException { + return QuantizationStateSerializer.serialize(this); + } + + /** + * Deserializes a MultiBitScalarQuantizationState object from a byte array. + * This method uses the QuantizationStateSerializer to handle the deserialization process. + * + *

The byte array should contain serialized state information, including the thresholds + * and quantization parameters, which are necessary to reconstruct the MultiBitScalarQuantizationState object.

+ * + *
+     * {@code
+     * byte[] serializedState = ...; // obtain the byte array from some source
+     * MultiBitScalarQuantizationState state = MultiBitScalarQuantizationState.fromByteArray(serializedState);
+     * }
+     * 
+ * + * @param bytes the byte array containing the serialized state. + * @return the deserialized MultiBitScalarQuantizationState object. + * @throws IOException if an I/O error occurs during deserialization. + */ + public static MultiBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException { + return (MultiBitScalarQuantizationState) QuantizationStateSerializer.deserialize(bytes, MultiBitScalarQuantizationState::new); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java new file mode 100644 index 000000000..9998b87e8 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.opensearch.Version; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; + +import java.io.IOException; + +/** + * OneBitScalarQuantizationState represents the state of one-bit scalar quantization, + * including the mean values used for quantization. + */ +@Getter +@NoArgsConstructor // No-argument constructor for deserialization +@AllArgsConstructor +public final class OneBitScalarQuantizationState implements QuantizationState { + private ScalarQuantizationParams quantizationParams; + /** + * Mean thresholds used in the quantization process. + * Each threshold value corresponds to a dimension of the vector being quantized. + * + * Example: + * If we have a vector [1.2, 3.4, 5.6] and mean thresholds [2.0, 3.0, 4.0], + * The quantized vector will be [0, 1, 1]. + */ + private float[] meanThresholds; + + @Override + public ScalarQuantizationParams getQuantizationParams() { + return quantizationParams; + } + + /** + * This method is responsible for writing the state of the OneBitScalarQuantizationState object to an external output. + * It includes versioning information to ensure compatibility between different versions of the serialized object. + * @param out the StreamOutput to write the object to. + * @throws IOException if an I/O error occurs during serialization. + */ + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(Version.CURRENT.id); // Write the version + quantizationParams.writeTo(out); + out.writeFloatArray(meanThresholds); + } + + /** + * This method is responsible for reading the state of the OneBitScalarQuantizationState object from an external input. + * It includes versioning information to ensure compatibility between different versions of the serialized object. + * @param in the StreamInput to read the object from. + * @throws IOException if an I/O error occurs during deserialization. + */ + public OneBitScalarQuantizationState(StreamInput in) throws IOException { + int version = in.readVInt(); // Read the version + this.quantizationParams = new ScalarQuantizationParams(in, version); + this.meanThresholds = in.readFloatArray(); + } + + /** + * Serializes the current state of this OneBitScalarQuantizationState object into a byte array. + * This method uses the QuantizationStateSerializer to handle the serialization process. + * + *

The serialized byte array includes all necessary state information, such as the mean thresholds + * and quantization parameters, ensuring that the object can be fully reconstructed from the byte array.

+ * + *
+     * {@code
+     * OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, meanThresholds);
+     * byte[] serializedState = state.toByteArray();
+     * }
+     * 
+ * + * @return a byte array representing the serialized state of this object. + * @throws IOException if an I/O error occurs during serialization. + */ + @Override + public byte[] toByteArray() throws IOException { + return QuantizationStateSerializer.serialize(this); + } + + /** + * Deserializes a OneBitScalarQuantizationState object from a byte array. + * This method uses the QuantizationStateSerializer to handle the deserialization process. + * + *

The byte array should contain serialized state information, including the mean thresholds + * and quantization parameters, which are necessary to reconstruct the OneBitScalarQuantizationState object.

+ * + *
+     * {@code
+     * byte[] serializedState = ...; // obtain the byte array from some source
+     * OneBitScalarQuantizationState state = OneBitScalarQuantizationState.fromByteArray(serializedState);
+     * }
+     * 
+ * + * @param bytes the byte array containing the serialized state. + * @return the deserialized OneBitScalarQuantizationState object. + * @throws IOException if an I/O error occurs during deserialization. + */ + public static OneBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException { + return (OneBitScalarQuantizationState) QuantizationStateSerializer.deserialize(bytes, OneBitScalarQuantizationState::new); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java new file mode 100644 index 000000000..e32df8bc3 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; + +import java.io.IOException; + +/** + * QuantizationState interface represents the state of a quantization process, including the parameters used. + * This interface provides methods for serializing and deserializing the state. + */ +public interface QuantizationState extends Writeable { + /** + * Returns the quantization parameters associated with this state. + * + * @return the quantization parameters. + */ + QuantizationParams getQuantizationParams(); + + /** + * Serializes the quantization state to a byte array. + * + * @return a byte array representing the serialized state. + * @throws IOException if an I/O error occurs during serialization. + */ + byte[] toByteArray() throws IOException; +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java new file mode 100644 index 000000000..1f378e0dc --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import lombok.experimental.UtilityClass; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; + +import java.io.IOException; + +/** + * QuantizationStateSerializer is a utility class that provides methods for serializing and deserializing + * QuantizationState objects along with their specific data. + */ +@UtilityClass +class QuantizationStateSerializer { + + /** + * A functional interface for deserializing specific data associated with a QuantizationState. + */ + @FunctionalInterface + interface SerializableDeserializer { + QuantizationState deserialize(StreamInput in) throws IOException; + } + + /** + * Serializes the QuantizationState and specific data into a byte array. + * + * @param state The QuantizationState to serialize. + * @return A byte array representing the serialized state and specific data. + * @throws IOException If an I/O error occurs during serialization. + */ + static byte[] serialize(QuantizationState state) throws IOException { + try (BytesStreamOutput out = new BytesStreamOutput()) { + state.writeTo(out); + return out.bytes().toBytesRef().bytes; + } + } + + /** + * Deserializes a QuantizationState and its specific data from a byte array. + * + * @param bytes The byte array containing the serialized data. + * @param deserializer The deserializer for the specific data associated with the state. + * @return The deserialized QuantizationState including its specific data. + * @throws IOException If an I/O error occurs during deserialization. + */ + static QuantizationState deserialize(byte[] bytes, SerializableDeserializer deserializer) throws IOException { + try (StreamInput in = StreamInput.wrap(bytes)) { + return deserializer.deserialize(in); + } + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java new file mode 100644 index 000000000..54ebe311c --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.requests; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * TrainingRequest represents a request for training a quantizer. + * + * @param the type of vectors to be trained. + */ +@Getter +@AllArgsConstructor +public abstract class TrainingRequest { + /** + * The total number of vectors in one segment. + */ + private final int totalNumberOfVectors; + + /** + * Returns the vector corresponding to the specified document ID. + * + * @param docId the document ID. + * @return the vector corresponding to the specified document ID. + */ + public abstract T getVectorByDocId(int docId); +} diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java b/src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java new file mode 100644 index 000000000..fe470ed74 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java @@ -0,0 +1,143 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import lombok.experimental.UtilityClass; + +/** + * The BitPacker class provides utility methods for quantizing floating-point vectors and packing the resulting bits + * into a pre-allocated byte array. This class supports both single-bit and multi-bit quantization scenarios, + * enabling efficient storage and transmission of quantized vectors. + * + *

+ * The methods in this class are designed to be used by quantizers that need to convert floating-point vectors + * into compact binary representations by comparing them against quantization thresholds. + *

+ * + *

+ * This class is marked as a utility class using Lombok's {@link lombok.experimental.UtilityClass} annotation, + * making it a singleton and preventing instantiation. + *

+ */ +@UtilityClass +class BitPacker { + + /** + * Quantizes a given floating-point vector and packs the resulting quantized bits into a provided byte array. + * This method operates by comparing each element of the input vector against corresponding thresholds + * and encoding the results into a compact binary format using the specified number of bits per coordinate. + * + *

+ * The method supports multi-bit quantization where each coordinate of the input vector can be represented + * by multiple bits. For example, with 2-bit quantization, each coordinate is encoded into 2 bits, allowing + * for four distinct levels of quantization per coordinate. + *

+ * + *

+ * Example: + *

+ *

+ * Consider a vector with 3 coordinates: [1.2, 3.4, 5.6] and thresholds: + *

+ *
+     * thresholds = {
+     *     {1.0, 3.0, 5.0},  // First bit thresholds
+     *     {1.5, 3.5, 5.5}   // Second bit thresholds
+     * };
+     * 
+ *

+ * If the number of bits per coordinate is 2, the quantization process will proceed as follows: + *

+ *
    + *
  • First bit comparison: + *
      + *
    • 1.2 > 1.0 -> 1
    • + *
    • 3.4 > 3.0 -> 1
    • + *
    • 5.6 > 5.0 -> 1
    • + *
    + *
  • + *
  • Second bit comparison: + *
      + *
    • 1.2 <= 1.5 -> 0
    • + *
    • 3.4 <= 3.5 -> 0
    • + *
    • 5.6 > 5.5 -> 1
    • + *
    + *
  • + *
+ *

+ * The resulting quantized bits will be 11 10 11, which is packed into the provided byte array. + * If there are fewer than 8 bits, the remaining bits in the byte are set to 0. + *

+ * + *

+ * Packing Process: + * The quantized bits are packed into the byte array. The first coordinate's bits are stored in the most + * significant positions of the first byte, followed by the second coordinate, and so on. In the example + * above, the resulting byte array will have the following binary representation: + *

+ *
+     * packedBits = [11011000] // Only the first 6 bits are used, and the last two are set to 0.
+     * 
+ * + *

Bitwise Operations Explanation:

+ *
    + *
  • byteIndex: This is calculated using byteIndex = bitPosition >> 3, which is equivalent to bitPosition / 8. It determines which byte in the byte array the current bit should be placed in.
  • + *
  • bitIndex: This is calculated using bitIndex = 7 - (bitPosition & 7), which is equivalent to 7 - (bitPosition % 8). It determines the exact bit position within the byte.
  • + *
  • Setting the bit: The bit is set using packedBits[byteIndex] |= (1 << bitIndex). This shifts a 1 into the correct bit position and ORs it with the existing byte value to set the bit.
  • + *
+ * + * @param vector the floating-point vector to be quantized. + * @param thresholds a 2D array representing the quantization thresholds. The first dimension corresponds to the number of bits per coordinate, and the second dimension corresponds to the vector's length. + * @param bitsPerCoordinate the number of bits used per coordinate, determining the granularity of the quantization. + * @param packedBits the byte array where the quantized bits will be packed. + */ + void quantizeAndPackBits(final float[] vector, final float[][] thresholds, final int bitsPerCoordinate, byte[] packedBits) { + int vectorLength = vector.length; + + for (int i = 0; i < bitsPerCoordinate; i++) { + for (int j = 0; j < vectorLength; j++) { + if (vector[j] > thresholds[i][j]) { + int bitPosition = i * vectorLength + j; + // Calculate the index of the byte in the packedBits array. + int byteIndex = bitPosition >> 3; // Equivalent to bitPosition / 8 + // Calculate the bit index within the byte. + int bitIndex = 7 - (bitPosition & 7); // Equivalent to 7 - (bitPosition % 8) + // Set the bit at the calculated position. + packedBits[byteIndex] |= (1 << bitIndex); // Set the bit at bitIndex + } + } + } + } + + /** + * Overloaded method to quantize a vector using single-bit quantization and pack the results into a provided byte array. + * + *

+ * This method is specifically designed for one-bit quantization scenarios, where each coordinate of the + * vector is represented by a single bit indicating whether the value is above or below the threshold. + *

+ * + *

Example:

+ *

+ * If we have a vector [1.2, 3.4, 5.6] and thresholds [2.0, 3.0, 4.0], the quantization process will be: + *

+ *
    + *
  • 1.2 < 2.0 -> 0
  • + *
  • 3.4 > 3.0 -> 1
  • + *
  • 5.6 > 4.0 -> 1
  • + *
+ *

+ * The quantized vector will be [0, 1, 1]. + *

+ * + * @param vector the vector to quantize. + * @param thresholds the thresholds for quantization, where each element represents the threshold for a corresponding coordinate. + * @param packedBits the byte array where the quantized bits will be packed. + */ + void quantizeAndPackBits(final float[] vector, final float[] thresholds, byte[] packedBits) { + quantizeAndPackBits(vector, new float[][] { thresholds }, 1, packedBits); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java new file mode 100644 index 000000000..dcf825a6a --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java @@ -0,0 +1,186 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplerType; +import org.opensearch.knn.quantization.sampler.SamplingFactory; + +/** + * MultiBitScalarQuantizer is responsible for quantizing vectors into multi-bit representations per dimension. + * Unlike the OneBitScalarQuantizer, which uses a single bit per dimension to represent whether a value is above + * or below a mean threshold, the MultiBitScalarQuantizer allows for multiple bits per dimension, enabling more + * granular and precise quantization. + * + *

+ * In a OneBitScalarQuantizer, each dimension of a vector is compared to a single threshold (the mean), and a single + * bit is used to indicate whether the value is above or below that threshold. This results in a very coarse + * representation where each dimension is either "on" or "off." + *

+ * + *

+ * The MultiBitScalarQuantizer, on the other hand, uses multiple thresholds per dimension. For example, in a 2-bit + * quantization scheme, three thresholds are used to divide each dimension into four possible regions. Each region + * is represented by a unique 2-bit value. This allows for a much finer representation of the data, capturing more + * nuances in the variation of each dimension. + *

+ * + *

+ * The thresholds in MultiBitScalarQuantizer are calculated based on the mean and standard deviation of the sampled + * vectors for each dimension. Here's how it works: + *

+ * + *
    + *
  • First, the mean and standard deviation are computed for each dimension across the sampled vectors.
  • + *
  • For each bit used in the quantization (e.g., 2 bits per coordinate), the thresholds are calculated + * using a linear combination of the mean and the standard deviation. The combination coefficients are + * determined by the number of bits, allowing the thresholds to split the data into equal probability regions. + *
  • + *
  • For example, in a 2-bit quantization (which divides data into four regions), the thresholds might be + * set at points corresponding to -1 standard deviation, 0 standard deviations (mean), and +1 standard deviation. + * This ensures that the data is evenly split into four regions, each represented by a 2-bit value. + *
  • + *
+ * + *

+ * The number of bits per coordinate is determined by the type of scalar quantization being applied, such as 2-bit + * or 4-bit quantization. The increased number of bits per coordinate in MultiBitScalarQuantizer allows for better + * preservation of information during the quantization process, making it more suitable for tasks where precision + * is crucial. However, this comes at the cost of increased storage and computational complexity compared to the + * simpler OneBitScalarQuantizer. + *

+ */ +public class MultiBitScalarQuantizer implements Quantizer { + private final int bitsPerCoordinate; // Number of bits used to quantize each dimension + private final int samplingSize; // Sampling size for training + private final Sampler sampler; // Sampler for training + private static final boolean IS_TRAINING_REQUIRED = true; + // Currently Lucene has sampling size as + // 25000 for segment level training , Keeping same + // to having consistent, Will revisit + // if this requires change + private static final int DEFAULT_SAMPLE_SIZE = 25000; + + /** + * Constructs a MultiBitScalarQuantizer with a specified number of bits per coordinate. + * + * @param bitsPerCoordinate the number of bits used per coordinate for quantization. + */ + public MultiBitScalarQuantizer(final int bitsPerCoordinate) { + this(bitsPerCoordinate, DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplerType.RESERVOIR)); + } + + /** + * Constructs a MultiBitScalarQuantizer with a specified number of bits per coordinate and sampling size. + * + * @param bitsPerCoordinate the number of bits used per coordinate for quantization. + * @param samplingSize the number of samples to use for training. + * @param sampler the sampler to use for training. + */ + public MultiBitScalarQuantizer(final int bitsPerCoordinate, final int samplingSize, final Sampler sampler) { + if (bitsPerCoordinate < 2) { + throw new IllegalArgumentException("bitsPerCoordinate must be greater than or equal to 2 for multibit quantizer."); + } + this.bitsPerCoordinate = bitsPerCoordinate; + this.samplingSize = samplingSize; + this.sampler = sampler; + } + + /** + * Trains the quantizer based on the provided training request, which should be of type SamplingTrainingRequest. + * The training process calculates the mean and standard deviation for each dimension and then determines + * threshold values for quantization based on these statistics. + * + * @param trainingRequest the request containing the data and parameters for training. + * @return a MultiBitScalarQuantizationState containing the computed thresholds. + */ + @Override + public QuantizationState train(final TrainingRequest trainingRequest) { + int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); + int dimension = trainingRequest.getVectorByDocId(sampledIndices[0]).length; + float[] meanArray = new float[dimension]; + float[] stdDevArray = new float[dimension]; + // Calculate sum, mean, and standard deviation in one pass + QuantizerHelper.calculateMeanAndStdDev(trainingRequest, sampledIndices, meanArray, stdDevArray); + float[][] thresholds = calculateThresholds(meanArray, stdDevArray, dimension); + ScalarQuantizationParams params = (bitsPerCoordinate == 2) + ? new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT) + : new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); + return new MultiBitScalarQuantizationState(params, thresholds); + } + + /** + * Quantizes the provided vector using the provided quantization state, producing a quantized output. + * The vector is quantized based on the thresholds in the quantization state. + * + * @param vector the vector to quantize. + * @param state the quantization state containing threshold information. + * @param output the QuantizationOutput object to store the quantized representation of the vector. + */ + @Override + public void quantize(final float[] vector, final QuantizationState state, final QuantizationOutput output) { + if (vector == null) { + throw new IllegalArgumentException("Vector to quantize must not be null."); + } + validateState(state); + MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) state; + float[][] thresholds = multiBitState.getThresholds(); + if (thresholds == null || thresholds[0].length != vector.length) { + throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); + } + // Prepare and get the writable array + byte[] writableArray = output.prepareAndGetWritableQuantizedVector(bitsPerCoordinate, vector.length); + BitPacker.quantizeAndPackBits(vector, thresholds, bitsPerCoordinate, writableArray); + } + + /** + * Calculates the thresholds for quantization based on mean and standard deviation. + * + * @param meanArray the mean for each dimension. + * @param stdDevArray the standard deviation for each dimension. + * @param dimension the number of dimensions in the vectors. + * @return the thresholds for quantization. + */ + private float[][] calculateThresholds(final float[] meanArray, final float[] stdDevArray, final int dimension) { + float[][] thresholds = new float[bitsPerCoordinate][dimension]; + float coef = bitsPerCoordinate + 1; + for (int i = 0; i < bitsPerCoordinate; i++) { + float iCoef = -1 + 2 * (i + 1) / coef; + for (int j = 0; j < dimension; j++) { + thresholds[i][j] = meanArray[j] + iCoef * stdDevArray[j]; + } + } + return thresholds; + } + + /** + * Validates the quantization state to ensure it is of the expected type. + * + * @param state the quantization state to validate. + * @throws IllegalArgumentException if the state is not an instance of MultiBitScalarQuantizationState. + */ + private void validateState(final QuantizationState state) { + if (!(state instanceof MultiBitScalarQuantizationState)) { + throw new IllegalArgumentException("Quantization state must be of type MultiBitScalarQuantizationState."); + } + } + + /** + * Returns the number of bits per coordinate used by this quantizer. + * + * @return the number of bits per coordinate. + */ + public int getBitsPerCoordinate() { + return bitsPerCoordinate; + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java new file mode 100644 index 000000000..41602dfd2 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplerType; +import org.opensearch.knn.quantization.sampler.SamplingFactory; + +/** + * OneBitScalarQuantizer is responsible for quantizing vectors using a single bit per dimension. + * It computes the mean of each dimension during training and then uses these means as thresholds + * for quantizing the vectors. + */ +public class OneBitScalarQuantizer implements Quantizer { + private final int samplingSize; // Sampling size for training + private static final boolean IS_TRAINING_REQUIRED = true; + private final Sampler sampler; // Sampler for training + // Currently Lucene has sampling size as + // 25000 for segment level training , Keeping same + // to having consistent, Will revisit + // if this requires change + private static final int DEFAULT_SAMPLE_SIZE = 25000; + + /** + * Constructs a OneBitScalarQuantizer with a default sampling size of 25000. + */ + public OneBitScalarQuantizer() { + this(DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplerType.RESERVOIR)); + } + + /** + * Constructs a OneBitScalarQuantizer with a specified sampling size. + * + * @param samplingSize the number of samples to use for training. + */ + public OneBitScalarQuantizer(final int samplingSize, final Sampler sampler) { + + this.samplingSize = samplingSize; + this.sampler = sampler; + } + + /** + * Trains the quantizer by calculating the mean of each dimension from the sampled vectors. + * These means are used as thresholds in the quantization process. + * + * @param trainingRequest the request containing the data and parameters for training. + * @return a OneBitScalarQuantizationState containing the calculated means. + */ + @Override + public QuantizationState train(final TrainingRequest trainingRequest) { + int[] sampledDocIds = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); + float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(trainingRequest, sampledDocIds); + return new OneBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), meanThresholds); + } + + /** + * Quantizes the provided vector using the given quantization state. + * It compares each dimension of the vector against the corresponding mean (threshold) to determine the quantized value. + * + * @param vector the vector to quantize. + * @param state the quantization state containing the means for each dimension. + * @param output the QuantizationOutput object to store the quantized representation of the vector. + */ + @Override + public void quantize(final float[] vector, final QuantizationState state, final QuantizationOutput output) { + if (vector == null) { + throw new IllegalArgumentException("Vector to quantize must not be null."); + } + validateState(state); + OneBitScalarQuantizationState binaryState = (OneBitScalarQuantizationState) state; + float[] thresholds = binaryState.getMeanThresholds(); + if (thresholds == null || thresholds.length != vector.length) { + throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); + } + // Prepare and get the writable array + byte[] writableArray = output.prepareAndGetWritableQuantizedVector(1, vector.length); + BitPacker.quantizeAndPackBits(vector, thresholds, writableArray); + } + + /** + * Validates the quantization state to ensure it is of the expected type. + * + * @param state the quantization state to validate. + * @throws IllegalArgumentException if the state is not an instance of OneBitScalarQuantizationState. + */ + private void validateState(final QuantizationState state) { + if (!(state instanceof OneBitScalarQuantizationState)) { + throw new IllegalArgumentException("Quantization state must be of type OneBitScalarQuantizationState."); + } + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java new file mode 100644 index 000000000..c0b297f5d --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; + +/** + * The Quantizer interface defines the methods required for training and quantizing vectors + * in the context of K-Nearest Neighbors (KNN) and similar machine learning tasks. + * It supports training to determine quantization parameters and quantizing data vectors + * based on these parameters. + * + * @param The type of the vector or data to be quantized. + * @param The type of the quantized output, typically a compressed or encoded representation. + */ +public interface Quantizer { + + /** + * Trains the quantizer based on the provided training request. The training process typically + * involves learning parameters that can be used to quantize vectors. + * + * @param trainingRequest the request containing data and parameters for training. + * @return a QuantizationState containing the learned parameters. + */ + QuantizationState train(TrainingRequest trainingRequest); + + /** + * Quantizes the provided vector using the specified quantization state. + * + * @param vector the vector to quantize. + * @param state the quantization state containing parameters for quantization. + * @param output the QuantizationOutput object to store the quantized representation of the vector. + */ + void quantize(T vector, QuantizationState state, QuantizationOutput output); +} diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java new file mode 100644 index 000000000..16f969973 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import lombok.experimental.UtilityClass; + +/** + * Utility class providing common methods for quantizer operations, such as parameter validation and + * extraction. This class is designed to be used with various quantizer implementations that require + * consistent handling of training requests and sampled indices. + */ +@UtilityClass +class QuantizerHelper { + /** + * Calculates the mean vector from a set of sampled vectors. + * + * @param samplingRequest The {@link TrainingRequest} containing the dataset and methods to access vectors by their indices. + * @param sampledIndices An array of indices representing the sampled vectors to be used for mean calculation. + * @return A float array representing the mean vector of the sampled vectors. + * @throws IllegalArgumentException If any of the vectors at the sampled indices are null. + * @throws IllegalStateException If the mean array is unexpectedly null after processing the vectors. + */ + static float[] calculateMeanThresholds(TrainingRequest samplingRequest, int[] sampledIndices) { + int totalSamples = sampledIndices.length; + float[] mean = null; + for (int docId : sampledIndices) { + float[] vector = samplingRequest.getVectorByDocId(docId); + if (vector == null) { + throw new IllegalArgumentException("Vector at sampled index " + docId + " is null."); + } + if (mean == null) { + mean = new float[vector.length]; + } + for (int j = 0; j < vector.length; j++) { + mean[j] += vector[j]; + } + } + if (mean == null) { + throw new IllegalStateException("Mean array should not be null after processing vectors."); + } + for (int j = 0; j < mean.length; j++) { + mean[j] /= totalSamples; + } + return mean; + } + + /** + * Calculates the mean and StdDev per dimension for sampled vectors. + * + * @param trainingRequest the request containing the data and parameters for training. + * @param sampledIndices the indices of the sampled vectors. + * @param meanArray the array to store the sum and then the mean of each dimension. + * @param stdDevArray the array to store the sum of squares and then the standard deviation of each dimension. + */ + static void calculateMeanAndStdDev( + TrainingRequest trainingRequest, + int[] sampledIndices, + float[] meanArray, + float[] stdDevArray + ) { + int totalSamples = sampledIndices.length; + int dimension = meanArray.length; + for (int docId : sampledIndices) { + float[] vector = trainingRequest.getVectorByDocId(docId); + if (vector == null) { + throw new IllegalArgumentException("Vector at sampled index " + docId + " is null."); + } + for (int j = 0; j < dimension; j++) { + meanArray[j] += vector[j]; + stdDevArray[j] += vector[j] * vector[j]; + } + } + + // Calculate mean and standard deviation in one pass + for (int j = 0; j < dimension; j++) { + meanArray[j] = meanArray[j] / totalSamples; + stdDevArray[j] = (float) Math.sqrt((stdDevArray[j] / totalSamples) - (meanArray[j] * meanArray[j])); + } + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java new file mode 100644 index 000000000..020efe54f --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.sampler; + +import lombok.NoArgsConstructor; + +import java.util.Arrays; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.IntStream; + +/** + * ReservoirSampler implements the Sampler interface and provides a method for sampling + * a specified number of indices from a total number of vectors using the reservoir sampling algorithm. + * This algorithm is particularly useful for randomly sampling a subset of data from a larger set + * when the total size of the dataset is unknown or very large. + */ +@NoArgsConstructor +final class ReservoirSampler implements Sampler { + /** + * Singleton instance holder. + */ + private static ReservoirSampler instance; + + /** + * Provides the singleton instance of ReservoirSampler. + * + * @return the singleton instance of ReservoirSampler. + */ + public static synchronized ReservoirSampler getInstance() { + if (instance == null) { + instance = new ReservoirSampler(); + } + return instance; + } + + /** + * Samples indices from the range [0, totalNumberOfVectors). + * If the total number of vectors is less than or equal to the sample size, it returns all indices. + * Otherwise, it uses the reservoir sampling algorithm to select a random subset. + * + * @param totalNumberOfVectors the total number of vectors to sample from. + * @param sampleSize the number of indices to sample. + * @return an array of sampled indices. + */ + @Override + public int[] sample(final int totalNumberOfVectors, final int sampleSize) { + if (totalNumberOfVectors <= sampleSize) { + return IntStream.range(0, totalNumberOfVectors).toArray(); + } + return reservoirSampleIndices(totalNumberOfVectors, sampleSize); + } + + /** + * Applies the reservoir sampling algorithm to select a random sample of indices. + * This method ensures that each index in the range [0, numVectors) has an equal probability + * of being included in the sample. + * + * Reservoir sampling is particularly useful for selecting a random sample from a large or unknown-sized dataset. + * For more information on the algorithm, see the following link: + * Reservoir Sampling - Wikipedia + * + * @param numVectors the total number of vectors. + * @param sampleSize the number of indices to sample. + * @return an array of sampled indices. + */ + private int[] reservoirSampleIndices(final int numVectors, final int sampleSize) { + int[] indices = new int[sampleSize]; + + // Initialize the reservoir with the first sampleSize elements + for (int i = 0; i < sampleSize; i++) { + indices[i] = i; + } + + // Replace elements with gradually decreasing probability + for (int i = sampleSize; i < numVectors; i++) { + int j = ThreadLocalRandom.current().nextInt(i + 1); + if (j < sampleSize) { + indices[j] = i; + } + } + + // Sort the sampled indices + Arrays.sort(indices); + + return indices; + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java new file mode 100644 index 000000000..5ff385972 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.sampler; + +/** + * The Sampler interface defines the contract for sampling strategies + * used in various quantization processes. Implementations of this + * interface should provide specific strategies for selecting a sample + * from a given set of vectors. + */ +public interface Sampler { + + /** + * Samples a subset of indices from the total number of vectors. + * + * @param totalNumberOfVectors the total number of vectors available. + * @param sampleSize the number of vectors to be sampled. + * @return an array of integers representing the indices of the sampled vectors. + * @throws IllegalArgumentException if the sample size is greater than the total number of vectors. + */ + int[] sample(int totalNumberOfVectors, int sampleSize); +} diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java b/src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java new file mode 100644 index 000000000..cd9b301df --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.sampler; + +/** + * SamplerType is an enumeration of the different types of samplers that can be created by the factory. + */ +public enum SamplerType { + RESERVOIR, // Represents a reservoir sampling strategy + // Add more enum values here for additional sampler types +} diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java new file mode 100644 index 000000000..80fe5bdae --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.sampler; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; + +/** + * SamplingFactory is a factory class for creating instances of Sampler. + * It uses the factory design pattern to encapsulate the creation logic for different types of samplers. + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public final class SamplingFactory { + + /** + * Creates and returns a Sampler instance based on the specified SamplerType. + * + * @param samplerType the type of sampler to create. + * @return a Sampler instance. + * @throws IllegalArgumentException if the sampler type is not supported. + */ + public static Sampler getSampler(final SamplerType samplerType) { + switch (samplerType) { + case RESERVOIR: + return ReservoirSampler.getInstance(); + // Add more cases for different samplers here + default: + throw new IllegalArgumentException("Unsupported sampler type: " + samplerType); + } + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java new file mode 100644 index 000000000..99621a0e5 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.enums; + +import org.opensearch.knn.KNNTestCase; + +import java.util.HashSet; +import java.util.Set; + +public class ScalarQuantizationTypeTests extends KNNTestCase { + public void testSQTypesValues() { + ScalarQuantizationType[] expectedValues = { + ScalarQuantizationType.ONE_BIT, + ScalarQuantizationType.TWO_BIT, + ScalarQuantizationType.FOUR_BIT }; + assertArrayEquals(expectedValues, ScalarQuantizationType.values()); + } + + public void testSQTypesValueOf() { + assertEquals(ScalarQuantizationType.ONE_BIT, ScalarQuantizationType.valueOf("ONE_BIT")); + assertEquals(ScalarQuantizationType.TWO_BIT, ScalarQuantizationType.valueOf("TWO_BIT")); + assertEquals(ScalarQuantizationType.FOUR_BIT, ScalarQuantizationType.valueOf("FOUR_BIT")); + } + + public void testUniqueSQTypeValues() { + Set uniqueIds = new HashSet<>(); + for (ScalarQuantizationType type : ScalarQuantizationType.values()) { + boolean added = uniqueIds.add(type.getId()); + assertTrue("Duplicate value found: " + type.getId(), added); + } + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java new file mode 100644 index 000000000..3474b7ec9 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.factory; + +import org.junit.Before; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.Quantizer; + +import java.lang.reflect.Field; +import java.util.concurrent.atomic.AtomicBoolean; + +public class QuantizerFactoryTests extends KNNTestCase { + + @Before + public void resetIsRegisteredFlag() throws NoSuchFieldException, IllegalAccessException { + Field isRegisteredField = QuantizerFactory.class.getDeclaredField("isRegistered"); + isRegisteredField.setAccessible(true); + AtomicBoolean isRegistered = (AtomicBoolean) isRegisteredField.get(null); + isRegistered.set(false); + } + + public void test_Lazy_Registration() { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + ScalarQuantizationParams paramsTwoBit = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + ScalarQuantizationParams paramsFourBit = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); + assertFalse(isRegisteredFieldAccessible()); + Quantizer quantizer = QuantizerFactory.getQuantizer(params); + Quantizer quantizerTwoBit = QuantizerFactory.getQuantizer(paramsTwoBit); + Quantizer quantizerFourBit = QuantizerFactory.getQuantizer(paramsFourBit); + assertTrue(quantizerFourBit instanceof MultiBitScalarQuantizer); + assertTrue(quantizerTwoBit instanceof MultiBitScalarQuantizer); + assertTrue(quantizer instanceof OneBitScalarQuantizer); + assertTrue(isRegisteredFieldAccessible()); + } + + public void testGetQuantizer_withNullParams() { + try { + QuantizerFactory.getQuantizer(null); + fail("Expected IllegalArgumentException"); + } catch (IllegalArgumentException e) { + assertEquals("Quantization parameters must not be null.", e.getMessage()); + } + } + + private boolean isRegisteredFieldAccessible() { + try { + Field isRegisteredField = QuantizerFactory.class.getDeclaredField("isRegistered"); + isRegisteredField.setAccessible(true); + AtomicBoolean isRegistered = (AtomicBoolean) isRegisteredField.get(null); + return isRegistered.get(); + } catch (NoSuchFieldException | IllegalAccessException e) { + fail("Failed to access isRegistered field."); + return false; + } + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java new file mode 100644 index 000000000..dec34e632 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.factory; + +import org.junit.BeforeClass; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; +import org.opensearch.knn.quantization.quantizer.Quantizer; + +public class QuantizerRegistryTests extends KNNTestCase { + + @BeforeClass + public static void setup() { + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT), + new OneBitScalarQuantizer() + ); + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.TWO_BIT), + new MultiBitScalarQuantizer(2) + ); + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.FOUR_BIT), + new MultiBitScalarQuantizer(4) + ); + } + + public void testRegisterAndGetQuantizer() { + // Test for OneBitScalarQuantizer + ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + Quantizer oneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams); + assertTrue(oneBitQuantizer instanceof OneBitScalarQuantizer); + + // Test for MultiBitScalarQuantizer (2-bit) + ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + Quantizer twoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams); + assertTrue(twoBitQuantizer instanceof MultiBitScalarQuantizer); + assertEquals(2, ((MultiBitScalarQuantizer) twoBitQuantizer).getBitsPerCoordinate()); + + // Test for MultiBitScalarQuantizer (4-bit) + ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); + Quantizer fourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); + assertTrue(fourBitQuantizer instanceof MultiBitScalarQuantizer); + assertEquals(4, ((MultiBitScalarQuantizer) fourBitQuantizer).getBitsPerCoordinate()); + } + + public void testQuantizerRegistryIsSingleton() { + // Ensure the same instance is returned for the same type identifier + ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + Quantizer firstOneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams); + Quantizer secondOneBitQuantizer = QuantizerRegistry.getQuantizer(oneBitParams); + assertSame(firstOneBitQuantizer, secondOneBitQuantizer); + + // Ensure the same instance is returned for the same type identifier (2-bit) + ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + Quantizer firstTwoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams); + Quantizer secondTwoBitQuantizer = QuantizerRegistry.getQuantizer(twoBitParams); + assertSame(firstTwoBitQuantizer, secondTwoBitQuantizer); + + // Ensure the same instance is returned for the same type identifier (4-bit) + ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); + Quantizer firstFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); + Quantizer secondFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); + assertSame(firstFourBitQuantizer, secondFourBitQuantizer); + } + + public void testRegisterQuantizerThrowsExceptionWhenAlreadyRegistered() { + ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + + // Attempt to register the same quantizer again should throw an exception + assertThrows(IllegalArgumentException.class, () -> { + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT), + new OneBitScalarQuantizer() + ); + }); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java new file mode 100644 index 000000000..fa25e8e80 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizationState; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; + +import java.io.IOException; + +public class QuantizationStateSerializerTests extends KNNTestCase { + + public void testSerializeAndDeserializeOneBitScalarQuantizationState() throws IOException { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + float[] mean = new float[] { 0.1f, 0.2f, 0.3f }; + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + + // Serialize + byte[] serialized = state.toByteArray(); + + OneBitScalarQuantizationState deserialized = OneBitScalarQuantizationState.fromByteArray(serialized); + + assertArrayEquals(mean, deserialized.getMeanThresholds(), 0.0f); + assertEquals(params, deserialized.getQuantizationParams()); + } + + public void testSerializeAndDeserializeMultiBitScalarQuantizationState() throws IOException { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + float[][] thresholds = new float[][] { { 0.1f, 0.2f, 0.3f }, { 0.4f, 0.5f, 0.6f } }; + MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + + // Serialize + byte[] serialized = state.toByteArray(); + MultiBitScalarQuantizationState deserialized = MultiBitScalarQuantizationState.fromByteArray(serialized); + + for (int i = 0; i < thresholds.length; i++) { + assertArrayEquals(thresholds[i], deserialized.getThresholds()[i], 0.0f); + } + assertEquals(params, deserialized.getQuantizationParams()); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java new file mode 100644 index 000000000..35edf49e2 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizationState; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; + +import java.io.IOException; + +public class QuantizationStateTests extends KNNTestCase { + + public void testOneBitScalarQuantizationStateSerialization() throws IOException { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + float[] mean = { 1.0f, 2.0f, 3.0f }; + + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + + // Serialize + byte[] serializedState = state.toByteArray(); + + // Deserialize + StreamInput in = StreamInput.wrap(serializedState); + OneBitScalarQuantizationState deserializedState = new OneBitScalarQuantizationState(in); + + float delta = 0.0001f; + assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta); + assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType()); + } + + public void testMultiBitScalarQuantizationStateSerialization() throws IOException { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + float[][] thresholds = { { 0.5f, 1.5f, 2.5f }, { 1.0f, 2.0f, 3.0f } }; + + MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + byte[] serializedState = state.toByteArray(); + + // Deserialize + StreamInput in = StreamInput.wrap(serializedState); + MultiBitScalarQuantizationState deserializedState = new MultiBitScalarQuantizationState(in); + + float delta = 0.0001f; + for (int i = 0; i < thresholds.length; i++) { + assertArrayEquals(thresholds[i], deserializedState.getThresholds()[i], delta); + } + assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType()); + } + + public void testSerializationWithDifferentVersions() throws IOException { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + float[] mean = { 1.0f, 2.0f, 3.0f }; + + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + byte[] serializedState = state.toByteArray(); + StreamInput in = StreamInput.wrap(serializedState); + OneBitScalarQuantizationState deserializedState = new OneBitScalarQuantizationState(in); + + float delta = 0.0001f; + assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta); + assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType()); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java new file mode 100644 index 000000000..ad6a44686 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java @@ -0,0 +1,107 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; + +import java.io.IOException; + +public class MultiBitScalarQuantizerTests extends KNNTestCase { + + public void testTrain_twoBit() { + float[][] vectors = { + { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f }, + { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }, + { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } }; + MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + TrainingRequest request = new MockTrainingRequest(params, vectors); + QuantizationState state = twoBitQuantizer.train(request); + + assertTrue(state instanceof MultiBitScalarQuantizationState); + MultiBitScalarQuantizationState mbState = (MultiBitScalarQuantizationState) state; + assertNotNull(mbState.getThresholds()); + assertEquals(2, mbState.getThresholds().length); // 2-bit quantization should have 2 thresholds + } + + public void testTrain_fourBit() { + MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4); + float[][] vectors = { + { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f }, + { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }, + { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } }; + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); + TrainingRequest request = new MockTrainingRequest(params, vectors); + QuantizationState state = fourBitQuantizer.train(request); + + assertTrue(state instanceof MultiBitScalarQuantizationState); + MultiBitScalarQuantizationState mbState = (MultiBitScalarQuantizationState) state; + assertNotNull(mbState.getThresholds()); + assertEquals(4, mbState.getThresholds().length); // 4-bit quantization should have 4 thresholds + } + + public void testQuantize_twoBit() throws IOException { + MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); + float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f }; + float[][] thresholds = { { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }, { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } }; + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + twoBitQuantizer.quantize(vector, state, output); + assertNotNull(output.getQuantizedVector()); + } + + public void testQuantize_fourBit() throws IOException { + MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4); + float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f }; + float[][] thresholds = { + { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }, + { 1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f, 7.1f, 8.1f }, + { 1.2f, 2.2f, 3.2f, 4.2f, 5.2f, 6.2f, 7.2f, 8.2f }, + { 1.3f, 2.3f, 3.3f, 4.3f, 5.3f, 6.3f, 7.3f, 8.3f } }; + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); + MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + fourBitQuantizer.quantize(vector, state, output); + assertNotNull(output.getQuantizedVector()); + } + + public void testQuantize_withNullVector() throws IOException { + MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + expectThrows( + IllegalArgumentException.class, + () -> twoBitQuantizer.quantize( + null, + new MultiBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT), new float[2][8]), + output + ) + ); + } + + // Mock classes for testing + private static class MockTrainingRequest extends TrainingRequest { + private final float[][] vectors; + + public MockTrainingRequest(ScalarQuantizationParams params, float[][] vectors) { + super(vectors.length); + this.vectors = vectors; + } + + @Override + public float[] getVectorByDocId(int docId) { + return vectors[docId]; + } + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java new file mode 100644 index 000000000..28be260d7 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplerType; +import org.opensearch.knn.quantization.sampler.SamplingFactory; + +import java.io.IOException; + +public class OneBitScalarQuantizerTests extends KNNTestCase { + + public void testTrain_withTrainingRequired() { + float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } }; + + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + TrainingRequest originalRequest = new TrainingRequest(vectors.length) { + @Override + public float[] getVectorByDocId(int docId) { + return vectors[docId]; + } + }; + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + QuantizationState state = quantizer.train(originalRequest); + + assertTrue(state instanceof OneBitScalarQuantizationState); + float[] meanThresholds = ((OneBitScalarQuantizationState) state).getMeanThresholds(); + assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, meanThresholds, 0.001f); + } + + public void testQuantize_withState() throws IOException { + float[] vector = { 3.0f, 6.0f, 9.0f }; + float[] thresholds = { 4.0f, 5.0f, 6.0f }; + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), + thresholds + ); + + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + quantizer.quantize(vector, state, output); + + assertNotNull(output); + byte[] expectedPackedBits = new byte[] { 0b01100000 }; // or 96 in decimal + assertArrayEquals(expectedPackedBits, output.getQuantizedVector()); + } + + public void testQuantize_withNullVector() throws IOException { + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), + new float[] { 0.0f } + ); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(null, state, output)); + } + + public void testQuantize_withInvalidState() throws IOException { + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + float[] vector = { 1.0f, 2.0f, 3.0f }; + QuantizationState invalidState = new QuantizationState() { + @Override + public ScalarQuantizationParams getQuantizationParams() { + return new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + } + + @Override + public byte[] toByteArray() { + return new byte[0]; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + // Empty implementation for test + } + }; + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, invalidState, output)); + } + + public void testQuantize_withMismatchedDimensions() throws IOException { + OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); + float[] vector = { 1.0f, 2.0f, 3.0f }; + float[] thresholds = { 4.0f, 5.0f }; + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState( + new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), + thresholds + ); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, state, output)); + } + + public void testCalculateMean() { + float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } }; + + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + TrainingRequest samplingRequest = new TrainingRequest(vectors.length) { + @Override + public float[] getVectorByDocId(int docId) { + return vectors[docId]; + } + }; + + Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR); + int[] sampledIndices = sampler.sample(vectors.length, 3); + float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(samplingRequest, sampledIndices); + assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, meanThresholds, 0.001f); + } + + public void testCalculateMean_withNullVector() { + float[][] vectors = { { 1.0f, 2.0f, 3.0f }, null, { 7.0f, 8.0f, 9.0f } }; + + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + TrainingRequest samplingRequest = new TrainingRequest(vectors.length) { + @Override + public float[] getVectorByDocId(int docId) { + return vectors[docId]; + } + }; + + Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR); + int[] sampledIndices = sampler.sample(vectors.length, 3); + expectThrows(IllegalArgumentException.class, () -> QuantizerHelper.calculateMeanThresholds(samplingRequest, sampledIndices)); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java new file mode 100644 index 000000000..59952eb10 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java @@ -0,0 +1,63 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.sampler; + +import org.opensearch.knn.KNNTestCase; + +import java.util.Arrays; +import java.util.stream.IntStream; + +public class ReservoirSamplerTests extends KNNTestCase { + + public void testSampleLessThanSampleSize() { + ReservoirSampler sampler = ReservoirSampler.getInstance(); + int totalNumberOfVectors = 5; + int sampleSize = 10; + int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray(); + assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); + } + + public void testSampleEqualToSampleSize() { + ReservoirSampler sampler = ReservoirSampler.getInstance(); + int totalNumberOfVectors = 10; + int sampleSize = 10; + int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray(); + assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); + } + + public void testSampleRandomness() { + ReservoirSampler sampler1 = ReservoirSampler.getInstance(); + ReservoirSampler sampler2 = ReservoirSampler.getInstance(); + int totalNumberOfVectors = 100; + int sampleSize = 10; + + int[] sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize); + int[] sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize); + + // It's unlikely but possible for the two samples to be equal, so we just check they are sorted correctly + Arrays.sort(sampledIndices1); + Arrays.sort(sampledIndices2); + assertFalse("Sampled indices should be different", Arrays.equals(sampledIndices1, sampledIndices2)); + } + + public void testEdgeCaseZeroVectors() { + ReservoirSampler sampler = ReservoirSampler.getInstance(); + int totalNumberOfVectors = 0; + int sampleSize = 10; + int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals("Sampled indices should be empty when there are zero vectors.", 0, sampledIndices.length); + } + + public void testEdgeCaseZeroSampleSize() { + ReservoirSampler sampler = ReservoirSampler.getInstance(); + int totalNumberOfVectors = 10; + int sampleSize = 0; + int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals("Sampled indices should be empty when sample size is zero.", 0, sampledIndices.length); + } +} diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java new file mode 100644 index 000000000..db8772b70 --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.sampler; + +import org.opensearch.knn.KNNTestCase; + +public class SamplingFactoryTests extends KNNTestCase { + public void testGetSampler_withReservoir() { + Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR); + assertTrue(sampler instanceof ReservoirSampler); + } + + public void testGetSampler_withUnsupportedType() { + expectThrows(NullPointerException.class, () -> SamplingFactory.getSampler(null)); // This should throw an exception + } +}