From 20f785b93f1f3373b712476fb9bf74dc66ef13c3 Mon Sep 17 00:00:00 2001 From: VIKASH TIWARI Date: Fri, 2 Aug 2024 22:35:49 -0700 Subject: [PATCH] Quantization Framework Implementation with 1bit and MultiBit Binary Quantizer Signed-off-by: VIKASH TIWARI --- CHANGELOG.md | 2 +- .../quantization/enums/QuantizationType.java | 34 ------ .../enums/ScalarQuantizationType.java | 32 ++++-- .../enums/ValueQuantizationType.java | 18 --- .../factory/QuantizerFactory.java | 6 +- .../factory/QuantizerRegistrar.java | 25 ++--- .../factory/QuantizerRegistry.java | 39 ++----- .../BinaryQuantizationOutput.java | 27 +++-- .../QuantizationOutput.java | 10 ++ .../QuantizationParams.java | 44 +++++--- .../models/quantizationParams/SQParams.java | 76 ++++++------- .../DefaultQuantizationState.java | 49 +++++++-- .../MultiBitScalarQuantizationState.java | 74 +++++++++---- .../OneBitScalarQuantizationState.java | 57 ++++++---- .../quantizationState/QuantizationState.java | 4 +- .../QuantizationStateSerializer.java | 68 ++++++++++++ .../models/requests/TrainingRequest.java | 55 +--------- .../{util => quantizer}/BitPacker.java | 18 ++- .../quantizer/MultiBitScalarQuantizer.java | 25 +++-- .../quantizer/OneBitScalarQuantizer.java | 22 ++-- .../knn/quantization/quantizer/Quantizer.java | 11 +- .../{util => quantizer}/QuantizerHelper.java | 48 +++----- .../sampler/ReservoirSampler.java | 50 ++++----- .../knn/quantization/sampler/Sampler.java | 19 +++- .../knn/quantization/sampler/SamplerType.java | 14 +++ .../quantization/sampler/SamplingFactory.java | 20 +--- .../util/QuantizationStateSerializer.java | 103 ------------------ .../enums/QuantizationTypeTests.java | 21 ---- .../knn/quantization/enums/SQTypesTests.java | 4 +- .../enums/ValueQuantizationTypeTests.java | 19 ---- .../factory/QuantizerFactoryTests.java | 10 -- .../factory/QuantizerRegistryTests.java | 31 +----- .../QuantizationStateSerializerTests.java | 4 +- .../QuantizationStateTests.java | 27 ++--- .../BitPackingUtilsTests.java | 4 +- .../MultiBitScalarQuantizerTests.java | 47 +++----- .../quantizer/OneBitScalarQuantizerTests.java | 54 ++++++--- .../sampler/ReservoirSamplerTests.java | 53 +++------ .../sampler/SamplingFactoryTests.java | 2 +- 39 files changed, 536 insertions(+), 690 deletions(-) delete mode 100644 src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java delete mode 100644 src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java create mode 100644 src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java rename src/main/java/org/opensearch/knn/quantization/{util => quantizer}/BitPacker.java (77%) rename src/main/java/org/opensearch/knn/quantization/{util => quantizer}/QuantizerHelper.java (67%) create mode 100644 src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java delete mode 100644 src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java delete mode 100644 src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java delete mode 100644 src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java rename src/test/java/org/opensearch/knn/quantization/{util => quantizer}/BitPackingUtilsTests.java (94%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 92af64ccb8..f47f8f20dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,4 +30,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920) * Refactor KNNVectorFieldType from KNNVectorFieldMapper to a separate class for better readability. [#1931](https://github.com/opensearch-project/k-NN/pull/1931) * Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925) -* Quantization Framework For Disk Optimized Vector Search and Implementation of Binary 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889) +* Added Quantization Framework and implemented 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889) diff --git a/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java deleted file mode 100644 index 4a2a17a574..0000000000 --- a/src/main/java/org/opensearch/knn/quantization/enums/QuantizationType.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.quantization.enums; - -/** - * The QuantizationType enum represents the different types of quantization - * that can be applied in the KNN. - * - * - */ -public enum QuantizationType { - /** - * Represents space quantization, typically involving dimensionality reduction - * or space partitioning techniques. - */ - SPACE, - - /** - * Represents value quantization, typically involving the conversion of continuous - * values into discrete ones. - */ - VALUE, -} diff --git a/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java index 88290c6a86..60a9ac8367 100644 --- a/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java +++ b/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java @@ -6,29 +6,39 @@ package org.opensearch.knn.quantization.enums; /** - * The SQTypes enum defines the various scalar quantization types that can be used - * in the KNN for vector quantization. - * Each type corresponds to a different bit-width representation of the quantized values. + * The ScalarQuantizationType enum defines the various scalar quantization types that can be used + * for vector quantization. + * Each type corresponds to a different bit and byte representation of the quantized values. */ public enum ScalarQuantizationType { /** * ONE_BIT quantization uses a single bit per coordinate. + * In the future , if you change the name , Please don't change value as + * serlization and deserlization depends on this */ - ONE_BIT, + ONE_BIT(1), /** * TWO_BIT quantization uses two bits per coordinate. + * In the future , if you change the name , Please don't change value as + * serlization and deserlization depends on this */ - TWO_BIT, + TWO_BIT(2), /** * FOUR_BIT quantization uses four bits per coordinate. + * In the future , if you change the name , Please don't change value as + * serlization and deserlization depends on this */ - FOUR_BIT, + FOUR_BIT(4); - /** - * UNSUPPORTED_TYPE is used to denote quantization types that are not supported. - * This can be used as a placeholder or default value. - */ - UNSUPPORTED_TYPE + private final int id; + + ScalarQuantizationType(int id) { + this.id = id; + } + + public int getId() { + return id; + } } diff --git a/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java deleted file mode 100644 index 43db46cf6e..0000000000 --- a/src/main/java/org/opensearch/knn/quantization/enums/ValueQuantizationType.java +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.quantization.enums; - -/** - * The ValueQuantizationType enum defines the types of value quantization techniques - * that can be applied in the KNN. - */ -public enum ValueQuantizationType { - /** - * SQ (Scalar Quantization) represents a method where each coordinate of the vector is quantized - * independently. - */ - SCALAR -} diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java index 985efd4cd1..b99f6ebdce 100644 --- a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerFactory.java @@ -5,6 +5,8 @@ package org.opensearch.knn.quantization.factory; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.quantizer.Quantizer; @@ -15,12 +17,10 @@ * based on the provided {@link QuantizationParams}. It uses a registry to look up the * appropriate quantizer implementation for the given quantization parameters. */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) public final class QuantizerFactory { private static final AtomicBoolean isRegistered = new AtomicBoolean(false); - // Private constructor to prevent instantiation - private QuantizerFactory() {} - /** * Ensures that default quantizers are registered. */ diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java index c8a2eb2bf6..bb56c9a13f 100644 --- a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java @@ -5,7 +5,8 @@ package org.opensearch.knn.quantization.factory; -import org.opensearch.knn.quantization.enums.QuantizationType; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.SQParams; import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer; @@ -15,11 +16,9 @@ * The QuantizerRegistrar class is responsible for registering default quantizers. * This class ensures that the registration happens only once in a thread-safe manner. */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) final class QuantizerRegistrar { - // Private constructor to prevent instantiation - private QuantizerRegistrar() {} - /** * Registers default quantizers if not already registered. *

@@ -27,22 +26,12 @@ private QuantizerRegistrar() {} * even in a multi-threaded environment. *

*/ - public static synchronized void registerDefaultQuantizers() { + static synchronized void registerDefaultQuantizers() { // Register OneBitScalarQuantizer for SQParams with VALUE_QUANTIZATION and SQTypes.ONE_BIT - QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE, ScalarQuantizationType.ONE_BIT, OneBitScalarQuantizer::new); + QuantizerRegistry.register(new SQParams(ScalarQuantizationType.ONE_BIT).getTypeIdentifier(), OneBitScalarQuantizer::new); // Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION with bit per co-ordinate = 2 - QuantizerRegistry.register( - SQParams.class, - QuantizationType.VALUE, - ScalarQuantizationType.TWO_BIT, - () -> new MultiBitScalarQuantizer(2) - ); + QuantizerRegistry.register(new SQParams(ScalarQuantizationType.TWO_BIT).getTypeIdentifier(), () -> new MultiBitScalarQuantizer(2)); // Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION with bit per co-ordinate = 4 - QuantizerRegistry.register( - SQParams.class, - QuantizationType.VALUE, - ScalarQuantizationType.FOUR_BIT, - () -> new MultiBitScalarQuantizer(4) - ); + QuantizerRegistry.register(new SQParams(ScalarQuantizationType.FOUR_BIT).getTypeIdentifier(), () -> new MultiBitScalarQuantizer(4)); } } diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java index 1243d79eff..260abad11c 100644 --- a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistry.java @@ -5,8 +5,8 @@ package org.opensearch.knn.quantization.factory; -import org.opensearch.knn.quantization.enums.QuantizationType; -import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.quantizer.Quantizer; @@ -19,32 +19,20 @@ * of quantizer instances. Quantizers are registered with specific quantization parameters * and type identifiers, allowing for efficient lookup and instantiation. */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) final class QuantizerRegistry { - - // Private constructor to prevent instantiation - private QuantizerRegistry() {} - // ConcurrentHashMap for thread-safe access private static final Map>> registry = new ConcurrentHashMap<>(); /** * Registers a quantizer with the registry. * - * @param paramClass the class of the quantization parameters - * @param quantizationType the quantization type (e.g., VALUE_QUANTIZATION) - * @param sqType the specific quantization subtype (e.g., ONE_BIT, TWO_BIT) + * @param paramIdentifier the unique identifier for the quantization parameters * @param quantizerSupplier a supplier that provides instances of the quantizer - * @param

the type of quantization parameters */ - public static

void register( - final Class

paramClass, - final QuantizationType quantizationType, - final ScalarQuantizationType sqType, - final Supplier> quantizerSupplier - ) { - String identifier = createIdentifier(quantizationType, sqType); + public static void register(final String paramIdentifier, final Supplier> quantizerSupplier) { // Ensure that the quantizer for this identifier is registered only once - registry.computeIfAbsent(identifier, key -> quantizerSupplier); + registry.computeIfAbsent(paramIdentifier, key -> quantizerSupplier); } /** @@ -60,23 +48,10 @@ public static

Quantizer getQuantizer(fin String identifier = params.getTypeIdentifier(); Supplier> supplier = registry.get(identifier); if (supplier == null) { - throw new IllegalArgumentException( - "No quantizer registered for type identifier: " + identifier + ". Available quantizers: " + registry.keySet() - ); + throw new IllegalArgumentException("No quantizer registered for type identifier: " + identifier); } @SuppressWarnings("unchecked") Quantizer quantizer = (Quantizer) supplier.get(); return quantizer; } - - /** - * Creates a unique identifier for the quantizer based on the quantization type and specific quantization subtype. - * - * @param quantizationType the quantization type - * @param sqType the specific quantization subtype - * @return a string identifier - */ - private static String createIdentifier(final QuantizationType quantizationType, final ScalarQuantizationType sqType) { - return quantizationType.name() + "_" + sqType.name(); - } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java index 18077182fc..e3b73b4cbf 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java @@ -5,27 +5,38 @@ package org.opensearch.knn.quantization.models.quantizationOutput; +import java.io.ByteArrayOutputStream; +import java.io.IOException; + /** * The BinaryQuantizationOutput class represents the output of a quantization process in binary format. * It implements the QuantizationOutput interface to handle byte arrays specifically. */ public class BinaryQuantizationOutput implements QuantizationOutput { - private final byte[] quantizedVector; + private final ByteArrayOutputStream byteArrayOutputStream; + + /** + * Constructs a BinaryQuantizationOutput instance with a default initial buffer size. + */ + public BinaryQuantizationOutput() { + this.byteArrayOutputStream = new ByteArrayOutputStream(); + } /** - * Constructs a BinaryQuantizationOutput instance with the specified quantized vector. + * Updates the quantized vector with a new byte array. * - * @param quantizedVector the quantized vector represented as a byte array. + * @param newQuantizedVector the new quantized vector represented as a byte array. */ - public BinaryQuantizationOutput(final byte[] quantizedVector) { - if (quantizedVector == null) { - throw new IllegalArgumentException("Quantized vector cannot be null"); + public void updateQuantizedVector(final byte[] newQuantizedVector) throws IOException { + if (newQuantizedVector == null || newQuantizedVector.length == 0) { + throw new IllegalArgumentException("Quantized vector cannot be null or empty"); } - this.quantizedVector = quantizedVector; + byteArrayOutputStream.reset(); + byteArrayOutputStream.write(newQuantizedVector); } @Override public byte[] getQuantizedVector() { - return quantizedVector; + return byteArrayOutputStream.toByteArray(); } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java index c5c5fd21f6..8f01a05946 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/QuantizationOutput.java @@ -5,6 +5,8 @@ package org.opensearch.knn.quantization.models.quantizationOutput; +import java.io.IOException; + /** * The QuantizationOutput interface defines the contract for quantization output data. * @@ -17,4 +19,12 @@ public interface QuantizationOutput { * @return the quantized data. */ T getQuantizedVector(); + + /** + * Updates the quantized vector with new data. + * + * @param newQuantizedVector the new quantized vector data. + * @throws IOException if an I/O error occurs during the update. + */ + void updateQuantizedVector(T newQuantizedVector) throws IOException; } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java index 2c982a3064..e19a6b79d9 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java @@ -5,9 +5,10 @@ package org.opensearch.knn.quantization.models.quantizationParams; -import org.opensearch.knn.quantization.enums.QuantizationType; - -import java.io.Serializable; +import java.io.Externalizable; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; /** * Interface for quantization parameters. @@ -16,17 +17,7 @@ * Implementations of this interface are expected to provide specific configurations * for various quantization strategies. */ -public interface QuantizationParams extends Serializable { - - /** - * Gets the quantization type associated with the parameters. - * The quantization type defines the overall strategy or method used - * for quantization, such as VALUE_QUANTIZATION or SPACE_QUANTIZATION. - * - * @return the {@link QuantizationType} indicating the quantization method. - */ - QuantizationType getQuantizationType(); - +public interface QuantizationParams extends Externalizable { /** * Provides a unique identifier for the quantization parameters. * This identifier is typically a combination of the quantization type @@ -36,4 +27,29 @@ public interface QuantizationParams extends Serializable { * @return a string representing the unique type identifier. */ String getTypeIdentifier(); + + /** + * Serializes the QuantizationParams object to an external output. + * Default implementation is no-op. + * + * @param out the ObjectOutput to write the object to. + * @throws IOException if an I/O error occurs during serialization. + */ + @Override + default void writeExternal(ObjectOutput out) throws IOException { + // Default no-op implementation + } + + /** + * Deserializes the QuantizationParams object from an external input. + * Default implementation is no-op. + * + * @param in the ObjectInput to read the object from. + * @throws IOException if an I/O error occurs during deserialization. + * @throws ClassNotFoundException if the class of the serialized object cannot be found. + */ + @Override + default void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + // Default no-op implementation + } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java index 0b6bbc9885..22a7c379ff 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/SQParams.java @@ -5,36 +5,27 @@ package org.opensearch.knn.quantization.models.quantizationParams; -import org.opensearch.knn.quantization.enums.QuantizationType; +import lombok.*; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; -import java.util.Objects; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.Locale; /** * The SQParams class represents the parameters specific to scalar quantization (SQ). * This class implements the QuantizationParams interface and includes the type of scalar quantization. */ +@Getter +@AllArgsConstructor +@NoArgsConstructor +@EqualsAndHashCode +@ToString public class SQParams implements QuantizationParams { - private final ScalarQuantizationType sqType; - - /** - * Constructs an SQParams instance with the specified scalar quantization type. - * - * @param sqType The specific type of scalar quantization (e.g., ONE_BIT, TWO_BIT, FOUR_BIT). - */ - public SQParams(final ScalarQuantizationType sqType) { - this.sqType = sqType; - } - - /** - * Returns the quantization type associated with these parameters. - * - * @return The quantization type, always VALUE_QUANTIZATION for SQParams. - */ - @Override - public QuantizationType getQuantizationType() { - return QuantizationType.VALUE; - } + private ScalarQuantizationType sqType; + private static final long serialVersionUID = 1L; // Version ID for serialization + private static final int CURRENT_VERSION = 1; // Current version of SQParams /** * Returns the scalar quantization type. @@ -46,38 +37,47 @@ public ScalarQuantizationType getSqType() { } /** - * Provides a unique type identifier for the SQParams, combining the quantization type and SQ type. + * Provides a unique type identifier for the SQParams, combining the SQ type. * This identifier is useful for distinguishing between different configurations of scalar quantization parameters. * * @return A string representing the unique type identifier. */ @Override public String getTypeIdentifier() { - return getQuantizationType().name() + "_" + sqType.name(); + return String.format(Locale.ROOT, "SQParams_%d", sqType.getId()); } /** - * Compares this object to the specified object. The result is true if and only if the argument is not null and is - * an SQParams object that represents the same scalar quantization type. + * Serializes the SQParams object to an external output. + * This method writes the scalar quantization type to the output stream. * - * @param o The object to compare this SQParams against. - * @return true if the given object represents an SQParams equivalent to this instance, false otherwise. + * @param out the ObjectOutput to write the object to. + * @throws IOException if an I/O error occurs during serialization. */ @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - SQParams sqParams = (SQParams) o; - return sqType == sqParams.sqType; + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(sqType); } /** - * Returns a hash code value for this SQParams instance. + * Deserializes the SQParams object from an external input with versioning. + * This method reads the scalar quantization type and new field from the input stream based on the version. * - * @return A hash code value for this SQParams instance. + * @param in the ObjectInput to read the object from. + * @param version the version of the serialized object. + * @throws IOException if an I/O error occurs during deserialization. + * @throws ClassNotFoundException if the class of the serialized object cannot be found. */ - @Override - public int hashCode() { - return Objects.hash(sqType); + public void readExternal(ObjectInput in, int version) throws IOException, ClassNotFoundException { + sqType = (ScalarQuantizationType) in.readObject(); + } + + /** + * Returns the current version of the SQParams class. + * This version is used for serialization and deserialization purposes. + * @return + */ + public int getVersion() { + return CURRENT_VERSION; } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java index acc8c2f009..4604b472ff 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java @@ -5,28 +5,27 @@ package org.opensearch.knn.quantization.models.quantizationState; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.models.quantizationParams.SQParams; -import org.opensearch.knn.quantization.util.QuantizationStateSerializer; import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; /** * DefaultQuantizationState is used as a fallback state when no training is required or if training fails. * It can be utilized by any quantizer to represent a default state. */ +@Getter +@NoArgsConstructor // No-argument constructor for deserialization +@AllArgsConstructor // Constructor with all arguments public class DefaultQuantizationState implements QuantizationState { - - private final QuantizationParams params; - - /** - * Constructs a DefaultQuantizationState with the given quantization parameters. - * - * @param params the quantization parameters. - */ - public DefaultQuantizationState(final QuantizationParams params) { - this.params = params; - } + private QuantizationParams params; + private static final long serialVersionUID = 1L; // Version ID for serialization + private static final int CURRENT_VERSION = 1; // Current version of SQParams /** * Returns the quantization parameters associated with this state. @@ -60,7 +59,33 @@ public byte[] toByteArray() throws IOException { public static DefaultQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException { return (DefaultQuantizationState) QuantizationStateSerializer.deserialize( bytes, + new DefaultQuantizationState(), (parentParams, specificData) -> new DefaultQuantizationState((SQParams) parentParams) ); } + + /** + * Writes the object to the output stream. + * This method is part of the Externalizable interface and is used to serialize the object. + * + * @param out the output stream to write the object to. + * @throws IOException if an I/O error occurs. + */ + @Override + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(params); + } + + /** + * Reads the object from the input stream. + * This method is part of the Externalizable interface and is used to deserialize the object. + * + * @param in the input stream to read the object from. + * @throws IOException if an I/O error occurs. + * @throws ClassNotFoundException if the class of the serialized object cannot be found. + */ + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + this.params = (QuantizationParams) in.readObject(); + } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java index 58834dd2c9..906562379a 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java @@ -5,43 +5,78 @@ package org.opensearch.knn.quantization.models.quantizationState; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; import org.opensearch.knn.quantization.models.quantizationParams.SQParams; -import org.opensearch.knn.quantization.util.QuantizationStateSerializer; import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; /** * MultiBitScalarQuantizationState represents the state of multi-bit scalar quantization, * including the thresholds used for quantization. */ +@Getter +@NoArgsConstructor // No-argument constructor for deserialization +@AllArgsConstructor // Constructor with all arguments public final class MultiBitScalarQuantizationState implements QuantizationState { - private final SQParams quantizationParams; - private final float[][] thresholds; - + private SQParams quantizationParams; /** - * Constructs a MultiBitScalarQuantizationState with the given quantization parameters and thresholds. + * The threshold values for multi-bit quantization, organized as a 2D array + * where each row corresponds to a different bit level. + * + * For example: + * - For 2-bit quantization: + * thresholds[0] -> {0.5f, 1.5f, 2.5f} // Thresholds for the first bit level + * thresholds[1] -> {1.0f, 2.0f, 3.0f} // Thresholds for the second bit level + * - For 4-bit quantization: + * thresholds[0] -> {0.1f, 0.2f, 0.3f} + * thresholds[1] -> {0.4f, 0.5f, 0.6f} + * thresholds[2] -> {0.7f, 0.8f, 0.9f} + * thresholds[3] -> {1.0f, 1.1f, 1.2f} * - * @param quantizationParams the scalar quantization parameters. - * @param thresholds the threshold values for multi-bit quantization, organized as a 2D array - * where each row corresponds to a different bit level. + * Each column represents the threshold for a specific dimension in the vector space. */ - public MultiBitScalarQuantizationState(final SQParams quantizationParams, final float[][] thresholds) { - this.quantizationParams = quantizationParams; - this.thresholds = thresholds; - } + private float[][] thresholds; + private static final long serialVersionUID = 1L; // Version ID for serialization + private static final int CURRENT_VERSION = 1; // Current version of SQParams @Override public SQParams getQuantizationParams() { return quantizationParams; } - /** - * Returns the thresholds used in the quantization process. - * - * @return a 2D array of threshold values. - */ - public float[][] getThresholds() { - return thresholds; + @Override + public void writeExternal(ObjectOutput out) throws IOException { + int combinedVersion = (CURRENT_VERSION << 16) | quantizationParams.getVersion(); + out.writeInt(combinedVersion); // Write the version + quantizationParams.writeExternal(out); + out.writeInt(thresholds.length); + out.writeInt(thresholds[0].length); + for (float[] row : thresholds) { + for (float value : row) { + out.writeFloat(value); + } + } + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + int combinedVersion = in.readInt(); + int stateVersion = (combinedVersion >> 16) & 0xFFFF; + int paramsVersion = combinedVersion & 0xFFFF; + quantizationParams = new SQParams(); + quantizationParams.readExternal(in, paramsVersion); // Use readExternal of SQParams + int rows = in.readInt(); + int cols = in.readInt(); + thresholds = new float[rows][cols]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + thresholds[i][j] = in.readFloat(); + } + } } @Override @@ -52,6 +87,7 @@ public byte[] toByteArray() throws IOException { public static MultiBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException { return (MultiBitScalarQuantizationState) QuantizationStateSerializer.deserialize( bytes, + new MultiBitScalarQuantizationState(), (parentParams, thresholds) -> new MultiBitScalarQuantizationState((SQParams) parentParams, (float[][]) thresholds) ); } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java index 9b4bad56a4..59d298f910 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java @@ -5,42 +5,56 @@ package org.opensearch.knn.quantization.models.quantizationState; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; import org.opensearch.knn.quantization.models.quantizationParams.SQParams; -import org.opensearch.knn.quantization.util.QuantizationStateSerializer; import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; /** * OneBitScalarQuantizationState represents the state of one-bit scalar quantization, * including the mean values used for quantization. */ +@Getter +@NoArgsConstructor // No-argument constructor for deserialization +@AllArgsConstructor // Constructor with all arguments public final class OneBitScalarQuantizationState implements QuantizationState { - private final SQParams quantizationParams; - private final float[] meanThresholds; - - /** - * Constructs a OneBitScalarQuantizationState with the given quantization parameters and mean values. - * - * @param quantizationParams the scalar quantization parameters. - * @param mean the mean values for each dimension. - */ - public OneBitScalarQuantizationState(final SQParams quantizationParams, final float[] mean) { - this.quantizationParams = quantizationParams; - this.meanThresholds = mean; - } + private SQParams quantizationParams; + private float[] meanThresholds; + private static final long serialVersionUID = 1L; // Version ID for serialization + private static final int CURRENT_VERSION = 1; // Current version of SQParams @Override public SQParams getQuantizationParams() { return quantizationParams; } - /** - * Returns the mean values used in the quantization process. - * - * @return an array of mean values. - */ - public float[] getMeanThresholds() { - return meanThresholds; + @Override + public void writeExternal(ObjectOutput out) throws IOException { + int combinedVersion = (CURRENT_VERSION << 16) | quantizationParams.getVersion(); + out.writeInt(combinedVersion); // Write the version + quantizationParams.writeExternal(out); + out.writeInt(meanThresholds.length); + for (float mean : meanThresholds) { + out.writeFloat(mean); + } + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + int combinedVersion = in.readInt(); // Read the combined version + int stateVersion = (combinedVersion >> 16) & 0xFFFF; + int paramsVersion = combinedVersion & 0xFFFF; + quantizationParams = new SQParams(); + quantizationParams.readExternal(in, paramsVersion); // Use readExternal of SQParams + int length = in.readInt(); + meanThresholds = new float[length]; + for (int i = 0; i < length; i++) { + meanThresholds[i] = in.readFloat(); + } } @Override @@ -51,6 +65,7 @@ public byte[] toByteArray() throws IOException { public static OneBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException { return (OneBitScalarQuantizationState) QuantizationStateSerializer.deserialize( bytes, + new OneBitScalarQuantizationState(), (parentParams, meanThresholds) -> new OneBitScalarQuantizationState((SQParams) parentParams, (float[]) meanThresholds) ); } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java index c17ff0641b..d3778fe299 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java @@ -7,14 +7,14 @@ import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import java.io.Externalizable; import java.io.IOException; -import java.io.Serializable; /** * QuantizationState interface represents the state of a quantization process, including the parameters used. * This interface provides methods for serializing and deserializing the state. */ -public interface QuantizationState extends Serializable { +public interface QuantizationState extends Externalizable { /** * Returns the quantization parameters associated with this state. * diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java new file mode 100644 index 0000000000..5f00d8e0ca --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import lombok.experimental.UtilityClass; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; + +import java.io.ByteArrayOutputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.io.IOException; +import java.io.ByteArrayInputStream; +import java.io.ObjectInputStream; + +/** + * QuantizationStateSerializer is a utility class that provides methods for serializing and deserializing + * QuantizationState objects along with their specific data. + */ +@UtilityClass +class QuantizationStateSerializer { + + /** + * A functional interface for deserializing specific data associated with a QuantizationState. + */ + @FunctionalInterface + interface SerializableDeserializer { + QuantizationState deserialize(QuantizationParams parentParams, Serializable specificData); + } + + /** + * Serializes the QuantizationState and specific data into a byte array. + * + * @param state The QuantizationState to serialize. + * @param specificData The specific data related to the state, to be serialized. + * @return A byte array representing the serialized state and specific data. + * @throws IOException If an I/O error occurs during serialization. + */ + static byte[] serialize(QuantizationState state, Serializable specificData) throws IOException { + try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutputStream out = new ObjectOutputStream(bos)) { + state.writeExternal(out); + out.writeObject(specificData); + out.flush(); + return bos.toByteArray(); + } + } + + /** + * Deserializes a QuantizationState and its specific data from a byte array. + * + * @param bytes The byte array containing the serialized data. + * @param stateInstance An instance of the state to call readExternal on. + * @param specificDataDeserializer The deserializer for the specific data associated with the state. + * @return The deserialized QuantizationState including its specific data. + * @throws IOException If an I/O error occurs during deserialization. + * @throws ClassNotFoundException If the class of the serialized object cannot be found. + */ + static QuantizationState deserialize(byte[] bytes, QuantizationState stateInstance, SerializableDeserializer specificDataDeserializer) + throws IOException, ClassNotFoundException { + try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ObjectInputStream in = new ObjectInputStream(bis)) { + stateInstance.readExternal(in); + Serializable specificData = (Serializable) in.readObject(); // Read the specific data + return specificDataDeserializer.deserialize(stateInstance.getQuantizationParams(), specificData); + } + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java index 14689ea476..54ebe311c6 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java +++ b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java @@ -5,64 +5,21 @@ package org.opensearch.knn.quantization.models.requests; -import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import lombok.AllArgsConstructor; +import lombok.Getter; /** * TrainingRequest represents a request for training a quantizer. * * @param the type of vectors to be trained. */ +@Getter +@AllArgsConstructor public abstract class TrainingRequest { - private final QuantizationParams params; - private final int totalNumberOfVectors; - private int[] sampledIndices; - - /** - * Constructs a TrainingRequest with the given parameters and total number of vectors. - * - * @param params the quantization parameters. - * @param totalNumberOfVectors the total number of vectors. - */ - protected TrainingRequest(final QuantizationParams params, final int totalNumberOfVectors) { - this.params = params; - this.totalNumberOfVectors = totalNumberOfVectors; - } - - /** - * Returns the quantization parameters. - * - * @return the quantization parameters. - */ - public QuantizationParams getParams() { - return params; - } - /** - * Returns the total number of vectors. - * - * @return the total number of vectors. - */ - public int getTotalNumberOfVectors() { - return totalNumberOfVectors; - } - - /** - * Sets the sampled indices for this training request. - * - * @param sampledIndices the sampled indices. + * The total number of vectors in one segment. */ - public void setSampledIndices(int[] sampledIndices) { - this.sampledIndices = sampledIndices; - } - - /** - * Returns the sampled indices for this training request. - * - * @return the sampled indices. - */ - public int[] getSampledIndices() { - return sampledIndices; - } + private final int totalNumberOfVectors; /** * Returns the vector corresponding to the specified document ID. diff --git a/src/main/java/org/opensearch/knn/quantization/util/BitPacker.java b/src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java similarity index 77% rename from src/main/java/org/opensearch/knn/quantization/util/BitPacker.java rename to src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java index 5d99a892fb..be95dc629b 100644 --- a/src/main/java/org/opensearch/knn/quantization/util/BitPacker.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java @@ -1,10 +1,9 @@ /* * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * */ -package org.opensearch.knn.quantization.util; +package org.opensearch.knn.quantization.quantizer; import lombok.experimental.UtilityClass; @@ -15,7 +14,7 @@ * Provides methods for packing arrays of bits into byte arrays for efficient storage or transmission. */ @UtilityClass -public class BitPacker { +class BitPacker { /** * Packs the list of bit arrays into a single byte array. @@ -25,14 +24,13 @@ public class BitPacker { * @return a byte array containing the packed bits. * @throws IllegalArgumentException if the bitArrays list is empty, if any bit array is null, or if bit arrays have inconsistent lengths. */ - public static byte[] packBits(List bitArrays) { - if (bitArrays.isEmpty()) { - throw new IllegalArgumentException("The list of bit arrays cannot be empty."); + static byte[] packBits(List bitArrays) { + if (bitArrays == null || bitArrays.isEmpty()) { + throw new IllegalArgumentException("The list of bit arrays cannot be null or empty."); } - int bitArrayLength = bitArrays.get(0).length; int bitLength = bitArrays.size() * bitArrayLength; - int byteLength = (bitLength + 7) / 8; + int byteLength = (bitLength + 7) >> 3; // Using bit shift instead of division by 8 byte[] packedArray = new byte[byteLength]; int bitPosition = 0; @@ -45,8 +43,8 @@ public static byte[] packBits(List bitArrays) { } for (byte bit : bitArray) { - int byteIndex = bitPosition / 8; - int bitIndex = 7 - (bitPosition % 8); + int byteIndex = bitPosition >> 3; // Using bit shift instead of division by 8 + int bitIndex = 7 - (bitPosition & 7); // Using bitwise AND instead of modulo by 8 if (bit == 1) { packedArray[byteIndex] |= (1 << bitIndex); } diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java index 0143a66147..3d3530f174 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java @@ -6,18 +6,19 @@ package org.opensearch.knn.quantization.quantizer; -import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.SQParams; import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.requests.TrainingRequest; import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplerType; import org.opensearch.knn.quantization.sampler.SamplingFactory; -import org.opensearch.knn.quantization.util.BitPacker; -import org.opensearch.knn.quantization.util.QuantizerHelper; +import java.io.IOException; import java.util.ArrayList; +import java.util.BitSet; import java.util.List; /** @@ -41,7 +42,7 @@ public class MultiBitScalarQuantizer implements Quantizer { * @param bitsPerCoordinate the number of bits used per coordinate for quantization. */ public MultiBitScalarQuantizer(final int bitsPerCoordinate) { - this(bitsPerCoordinate, DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR)); + this(bitsPerCoordinate, DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplerType.RESERVOIR)); } /** @@ -70,15 +71,16 @@ public MultiBitScalarQuantizer(final int bitsPerCoordinate, final int samplingSi */ @Override public QuantizationState train(final TrainingRequest trainingRequest) { - SQParams params = QuantizerHelper.validateAndExtractParams(trainingRequest); - int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); - - int dimension = trainingRequest.getVectorByDocId(sampledIndices[0]).length; + BitSet sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); + int dimension = trainingRequest.getVectorByDocId(sampledIndices.nextSetBit(0)).length; float[] meanArray = new float[dimension]; float[] stdDevArray = new float[dimension]; // Calculate sum, mean, and standard deviation in one pass QuantizerHelper.calculateSumMeanAndStdDev(trainingRequest, sampledIndices, meanArray, stdDevArray); float[][] thresholds = calculateThresholds(meanArray, stdDevArray, dimension); + SQParams params = (bitsPerCoordinate == 2) + ? new SQParams(ScalarQuantizationType.TWO_BIT) + : new SQParams(ScalarQuantizationType.FOUR_BIT); return new MultiBitScalarQuantizationState(params, thresholds); } @@ -88,10 +90,11 @@ public QuantizationState train(final TrainingRequest trainingRequest) { * * @param vector the vector to quantize. * @param state the quantization state containing threshold information. - * @return a BinaryQuantizationOutput containing the quantized data. + * @param output the QuantizationOutput object to store the quantized representation of the vector. + * @throws IOException if an I/O error occurs during quantization. */ @Override - public QuantizationOutput quantize(final float[] vector, final QuantizationState state) { + public void quantize(final float[] vector, final QuantizationState state, final QuantizationOutput output) throws IOException { if (vector == null) { throw new IllegalArgumentException("Vector to quantize must not be null."); } @@ -111,7 +114,7 @@ public QuantizationOutput quantize(final float[] vector, final Quantizat bitArrays.add(bitArray); } - return new BinaryQuantizationOutput(BitPacker.packBits(bitArrays)); + output.updateQuantizedVector(BitPacker.packBits(bitArrays)); } /** diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java index 2eaa07ce03..5f87e83ba9 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java @@ -5,17 +5,18 @@ package org.opensearch.knn.quantization.quantizer; -import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.SQParams; import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.requests.TrainingRequest; import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplerType; import org.opensearch.knn.quantization.sampler.SamplingFactory; -import org.opensearch.knn.quantization.util.BitPacker; -import org.opensearch.knn.quantization.util.QuantizerHelper; +import java.io.IOException; +import java.util.BitSet; import java.util.Collections; /** @@ -37,7 +38,7 @@ public class OneBitScalarQuantizer implements Quantizer { * Constructs a OneBitScalarQuantizer with a default sampling size of 25000. */ public OneBitScalarQuantizer() { - this(DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR)); + this(DEFAULT_SAMPLE_SIZE, SamplingFactory.getSampler(SamplerType.RESERVOIR)); } /** @@ -49,7 +50,6 @@ public OneBitScalarQuantizer(final int samplingSize, final Sampler sampler) { this.samplingSize = samplingSize; this.sampler = sampler; - ; } /** @@ -61,10 +61,9 @@ public OneBitScalarQuantizer(final int samplingSize, final Sampler sampler) { */ @Override public QuantizationState train(final TrainingRequest trainingRequest) { - SQParams params = QuantizerHelper.validateAndExtractParams(trainingRequest); - int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); - float[] mean = QuantizerHelper.calculateMean(trainingRequest, sampledIndices); - return new OneBitScalarQuantizationState(params, mean); + BitSet sampledDocIds = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); + float[] mean = QuantizerHelper.calculateMean(trainingRequest, sampledDocIds); + return new OneBitScalarQuantizationState(new SQParams(ScalarQuantizationType.ONE_BIT), mean); } /** @@ -73,10 +72,11 @@ public QuantizationState train(final TrainingRequest trainingRequest) { * * @param vector the vector to quantize. * @param state the quantization state containing the means for each dimension. + * @param output the QuantizationOutput object to store the quantized representation of the vector. * @return a BinaryQuantizationOutput containing the quantized data. */ @Override - public QuantizationOutput quantize(final float[] vector, final QuantizationState state) { + public void quantize(final float[] vector, final QuantizationState state, final QuantizationOutput output) throws IOException { if (vector == null) { throw new IllegalArgumentException("Vector to quantize must not be null."); } @@ -90,7 +90,7 @@ public QuantizationOutput quantize(final float[] vector, final Quantizat for (int i = 0; i < vector.length; i++) { quantizedVector[i] = (byte) (vector[i] > thresholds[i] ? 1 : 0); } - return new BinaryQuantizationOutput(BitPacker.packBits(Collections.singletonList(quantizedVector))); + output.updateQuantizedVector(BitPacker.packBits(Collections.singletonList(quantizedVector))); } /** diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java index 8231a8aa27..beabd8d738 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java @@ -1,14 +1,11 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - package org.opensearch.knn.quantization.quantizer; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import java.io.IOException; + /** * The Quantizer interface defines the methods required for training and quantizing vectors * in the context of K-Nearest Neighbors (KNN) and similar machine learning tasks. @@ -34,7 +31,7 @@ public interface Quantizer { * * @param vector the vector to quantize. * @param state the quantization state containing parameters for quantization. - * @return a QuantizationOutput containing the quantized representation of the vector. + * @param output the QuantizationOutput object to store the quantized representation of the vector. */ - QuantizationOutput quantize(T vector, QuantizationState state); + void quantize(T vector, QuantizationState state, QuantizationOutput output) throws IOException; } diff --git a/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java similarity index 67% rename from src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java rename to src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java index adc4e34c43..c017d9eb4b 100644 --- a/src/main/java/org/opensearch/knn/quantization/util/QuantizerHelper.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java @@ -1,40 +1,22 @@ /* * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * */ -package org.opensearch.knn.quantization.util; +package org.opensearch.knn.quantization.quantizer; -import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; -import org.opensearch.knn.quantization.models.quantizationParams.SQParams; import org.opensearch.knn.quantization.models.requests.TrainingRequest; import lombok.experimental.UtilityClass; +import java.util.BitSet; + /** * Utility class providing common methods for quantizer operations, such as parameter validation and * extraction. This class is designed to be used with various quantizer implementations that require * consistent handling of training requests and sampled indices. */ @UtilityClass -public class QuantizerHelper { - - /** - * Validates the provided training request to ensure it contains non-null quantization parameters. - * Extracts and returns the SQParams from the training request. - * - * @param trainingRequest the training request to validate and extract parameters from. - * @return the extracted SQParams. - * @throws IllegalArgumentException if the SQParams are null. - */ - public static SQParams validateAndExtractParams(TrainingRequest trainingRequest) { - QuantizationParams params = trainingRequest.getParams(); - if (params == null || !(params instanceof SQParams)) { - throw new IllegalArgumentException("Quantization parameters must not be null and must be of type SQParams."); - } - return (SQParams) params; - } - +class QuantizerHelper { /** * Calculates the mean vector from a set of sampled vectors. * @@ -49,13 +31,13 @@ public static SQParams validateAndExtractParams(TrainingRequest trainingReque * @throws IllegalArgumentException If any of the vectors at the sampled indices are null. * @throws IllegalStateException If the mean array is unexpectedly null after processing the vectors. */ - public static float[] calculateMean(TrainingRequest samplingRequest, int[] sampledIndices) { - int totalSamples = sampledIndices.length; + static float[] calculateMean(TrainingRequest samplingRequest, BitSet sampledIndices) { + int totalSamples = sampledIndices.cardinality(); float[] mean = null; - for (int index : sampledIndices) { - float[] vector = samplingRequest.getVectorByDocId(index); + for (int docId = sampledIndices.nextSetBit(0); docId >= 0; docId = sampledIndices.nextSetBit(docId + 1)) { + float[] vector = samplingRequest.getVectorByDocId(docId); if (vector == null) { - throw new IllegalArgumentException("Vector at sampled index " + index + " is null."); + throw new IllegalArgumentException("Vector at sampled index " + docId + " is null."); } if (mean == null) { mean = new float[vector.length]; @@ -81,20 +63,20 @@ public static float[] calculateMean(TrainingRequest samplingRequest, in * @param meanArray the array to store the sum and then the mean of each dimension. * @param stdDevArray the array to store the sum of squares and then the standard deviation of each dimension. */ - public static void calculateSumMeanAndStdDev( + static void calculateSumMeanAndStdDev( TrainingRequest trainingRequest, - int[] sampledIndices, + BitSet sampledIndices, float[] meanArray, float[] stdDevArray ) { - int totalSamples = sampledIndices.length; + int totalSamples = sampledIndices.cardinality(); int dimension = meanArray.length; // Single pass to calculate sum and sum of squares - for (int index : sampledIndices) { - float[] vector = trainingRequest.getVectorByDocId(index); + for (int docId = sampledIndices.nextSetBit(0); docId >= 0; docId = sampledIndices.nextSetBit(docId + 1)) { + float[] vector = trainingRequest.getVectorByDocId(docId); if (vector == null) { - throw new IllegalArgumentException("Vector at sampled index " + index + " is null."); + throw new IllegalArgumentException("Vector at sampled index " + docId + " is null."); } for (int j = 0; j < dimension; j++) { meanArray[j] += vector[j]; diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java index da5327def2..7bf7d30384 100644 --- a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java +++ b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java @@ -5,10 +5,9 @@ package org.opensearch.knn.quantization.sampler; -import java.util.Arrays; +import java.util.BitSet; import java.util.Random; import java.util.concurrent.ThreadLocalRandom; -import java.util.stream.IntStream; /** * ReservoirSampler implements the Sampler interface and provides a method for sampling @@ -24,25 +23,7 @@ final class ReservoirSampler implements Sampler { * Constructs a ReservoirSampler with a new Random instance. */ public ReservoirSampler() { - this(ThreadLocalRandom.current()); - } - - /** - * Constructs a ReservoirSampler with a specified random seed for reproducibility. - * - * @param seed the seed for the random number generator. - */ - public ReservoirSampler(final long seed) { - this(new Random(seed)); - } - - /** - * Constructs a ReservoirSampler with a specified Random instance. - * - * @param random the Random instance for generating random numbers. - */ - public ReservoirSampler(final Random random) { - this.random = random; + this.random = ThreadLocalRandom.current(); } /** @@ -55,9 +36,11 @@ public ReservoirSampler(final Random random) { * @return an array of sampled indices. */ @Override - public int[] sample(final int totalNumberOfVectors, final int sampleSize) { + public BitSet sample(final int totalNumberOfVectors, final int sampleSize) { if (totalNumberOfVectors <= sampleSize) { - return IntStream.range(0, totalNumberOfVectors).toArray(); + BitSet bitSet = new BitSet(totalNumberOfVectors); + bitSet.set(0, totalNumberOfVectors); + return bitSet; } return reservoirSampleIndices(totalNumberOfVectors, sampleSize); } @@ -67,19 +50,30 @@ public int[] sample(final int totalNumberOfVectors, final int sampleSize) { * This method ensures that each index in the range [0, numVectors) has an equal probability * of being included in the sample. * + * Reservoir sampling is particularly useful for selecting a random sample from a large or unknown-sized dataset. + * For more information on the algorithm, see the following link: + * Reservoir Sampling - Wikipedia + * * @param numVectors the total number of vectors. * @param sampleSize the number of indices to sample. - * @return an array of sampled indices. + * @return a BitSet representing the sampled indices. */ - private int[] reservoirSampleIndices(final int numVectors, final int sampleSize) { - int[] indices = IntStream.range(0, sampleSize).toArray(); + private BitSet reservoirSampleIndices(final int numVectors, final int sampleSize) { + int[] indices = new int[sampleSize]; + for (int i = 0; i < sampleSize; i++) { + indices[i] = i; + } for (int i = sampleSize; i < numVectors; i++) { int j = random.nextInt(i + 1); if (j < sampleSize) { indices[j] = i; } } - Arrays.sort(indices); - return indices; + // Using BitSet to track the presence of indices + BitSet bitSet = new BitSet(numVectors); + for (int i = 0; i < sampleSize; i++) { + bitSet.set(indices[i]); + } + return bitSet; } } diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java index 9021073b4e..17834cf043 100644 --- a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java +++ b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java @@ -5,6 +5,23 @@ package org.opensearch.knn.quantization.sampler; +import java.util.BitSet; + +/** + * The Sampler interface defines the contract for sampling strategies + * used in various quantization processes. Implementations of this + * interface should provide specific strategies for selecting a sample + * from a given set of vectors. + */ public interface Sampler { - int[] sample(int totalNumberOfVectors, int sampleSize); + + /** + * Samples a subset of indices from the total number of vectors. + * + * @param totalNumberOfVectors the total number of vectors available. + * @param sampleSize the number of vectors to be sampled. + * @return an array of integers representing the indices of the sampled vectors. + * @throws IllegalArgumentException if the sample size is greater than the total number of vectors. + */ + BitSet sample(int totalNumberOfVectors, int sampleSize); } diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java b/src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java new file mode 100644 index 0000000000..cd9b301dff --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/sampler/SamplerType.java @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.sampler; + +/** + * SamplerType is an enumeration of the different types of samplers that can be created by the factory. + */ +public enum SamplerType { + RESERVOIR, // Represents a reservoir sampling strategy + // Add more enum values here for additional sampler types +} diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java index be228fe6f9..1545a127df 100644 --- a/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java +++ b/src/main/java/org/opensearch/knn/quantization/sampler/SamplingFactory.java @@ -5,28 +5,16 @@ package org.opensearch.knn.quantization.sampler; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; + /** * SamplingFactory is a factory class for creating instances of Sampler. * It uses the factory design pattern to encapsulate the creation logic for different types of samplers. */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) public final class SamplingFactory { - /** - * Private constructor to prevent instantiation of this class. - * The class is not meant to be instantiated, as it provides static methods only. - */ - private SamplingFactory() { - - } - - /** - * SamplerType is an enumeration of the different types of samplers that can be created by the factory. - */ - public enum SamplerType { - RESERVOIR, // Represents a reservoir sampling strategy - // Add more enum values here for additional sampler types - } - /** * Creates and returns a Sampler instance based on the specified SamplerType. * diff --git a/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java b/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java deleted file mode 100644 index 89b3b67bd8..0000000000 --- a/src/main/java/org/opensearch/knn/quantization/util/QuantizationStateSerializer.java +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.quantization.util; - -import lombok.experimental.UtilityClass; -import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; -import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; - -import java.io.ByteArrayOutputStream; -import java.io.ObjectOutputStream; -import java.io.Serializable; -import java.io.IOException; -import java.io.ByteArrayInputStream; -import java.io.ObjectInputStream; - -/** - * QuantizationStateSerializer is a utility class that provides methods for serializing and deserializing - * QuantizationState objects along with their specific data. - */ -@UtilityClass -public class QuantizationStateSerializer { - - /** - * A functional interface for deserializing specific data associated with a QuantizationState. - */ - @FunctionalInterface - public interface SerializableDeserializer { - QuantizationState deserialize(QuantizationParams parentParams, Serializable specificData); - } - - /** - * Serializes the QuantizationState and specific data into a byte array. - * - * @param state The QuantizationState to serialize. - * @param specificData The specific data related to the state, to be serialized. - * @return A byte array representing the serialized state and specific data. - * @throws IOException If an I/O error occurs during serialization. - */ - public static byte[] serialize(QuantizationState state, Serializable specificData) throws IOException { - byte[] parentBytes = serializeParentParams(state.getQuantizationParams()); - try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutputStream out = new ObjectOutputStream(bos)) { - out.writeInt(parentBytes.length); // Write the length of the parent bytes - out.write(parentBytes); // Write the parent bytes - out.writeObject(specificData); // Write the specific data - out.flush(); - return bos.toByteArray(); - } - } - - /** - * Deserializes a QuantizationState and its specific data from a byte array. - * - * @param bytes The byte array containing the serialized data. - * @param specificDataDeserializer The deserializer for the specific data associated with the state. - * @return The deserialized QuantizationState including its specific data. - * @throws IOException If an I/O error occurs during deserialization. - * @throws ClassNotFoundException If the class of the serialized object cannot be found. - */ - public static QuantizationState deserialize(byte[] bytes, SerializableDeserializer specificDataDeserializer) throws IOException, - ClassNotFoundException { - try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ObjectInputStream in = new ObjectInputStream(bis)) { - int parentLength = in.readInt(); - // Read the length of the parent bytes - byte[] parentBytes = new byte[parentLength]; - in.readFully(parentBytes); // Read the parent bytes - QuantizationParams parentParams = deserializeParentParams(parentBytes); // Deserialize the parent params - Serializable specificData = (Serializable) in.readObject(); // Read the specific data - return specificDataDeserializer.deserialize(parentParams, specificData); - } - } - - /** - * Serializes the parent parameters of the QuantizationState into a byte array. - * - * @param params The QuantizationParams to serialize. - * @return A byte array representing the serialized parent parameters. - * @throws IOException If an I/O error occurs during serialization. - */ - private static byte[] serializeParentParams(QuantizationParams params) throws IOException { - try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutputStream out = new ObjectOutputStream(bos)) { - out.writeObject(params); - out.flush(); - return bos.toByteArray(); - } - } - - /** - * Deserializes the parent parameters of the QuantizationState from a byte array. - * - * @param bytes The byte array containing the serialized parent parameters. - * @return The deserialized QuantizationParams. - * @throws IOException If an I/O error occurs during deserialization. - * @throws ClassNotFoundException If the class of the serialized object cannot be found. - */ - private static QuantizationParams deserializeParentParams(byte[] bytes) throws IOException, ClassNotFoundException { - try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ObjectInputStream in = new ObjectInputStream(bis)) { - return (QuantizationParams) in.readObject(); - } - } -} diff --git a/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java deleted file mode 100644 index 598a07867c..0000000000 --- a/src/test/java/org/opensearch/knn/quantization/enums/QuantizationTypeTests.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.quantization.enums; - -import org.opensearch.knn.KNNTestCase; - -public class QuantizationTypeTests extends KNNTestCase { - - public void testQuantizationTypeValues() { - QuantizationType[] expectedValues = { QuantizationType.SPACE, QuantizationType.VALUE }; - assertArrayEquals(expectedValues, QuantizationType.values()); - } - - public void testQuantizationTypeValueOf() { - assertEquals(QuantizationType.SPACE, QuantizationType.valueOf("SPACE")); - assertEquals(QuantizationType.VALUE, QuantizationType.valueOf("VALUE")); - } -} diff --git a/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java b/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java index 3498114a67..d55788cf87 100644 --- a/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java +++ b/src/test/java/org/opensearch/knn/quantization/enums/SQTypesTests.java @@ -12,8 +12,7 @@ public void testSQTypesValues() { ScalarQuantizationType[] expectedValues = { ScalarQuantizationType.ONE_BIT, ScalarQuantizationType.TWO_BIT, - ScalarQuantizationType.FOUR_BIT, - ScalarQuantizationType.UNSUPPORTED_TYPE }; + ScalarQuantizationType.FOUR_BIT }; assertArrayEquals(expectedValues, ScalarQuantizationType.values()); } @@ -21,6 +20,5 @@ public void testSQTypesValueOf() { assertEquals(ScalarQuantizationType.ONE_BIT, ScalarQuantizationType.valueOf("ONE_BIT")); assertEquals(ScalarQuantizationType.TWO_BIT, ScalarQuantizationType.valueOf("TWO_BIT")); assertEquals(ScalarQuantizationType.FOUR_BIT, ScalarQuantizationType.valueOf("FOUR_BIT")); - assertEquals(ScalarQuantizationType.UNSUPPORTED_TYPE, ScalarQuantizationType.valueOf("UNSUPPORTED_TYPE")); } } diff --git a/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java deleted file mode 100644 index 3da665630e..0000000000 --- a/src/test/java/org/opensearch/knn/quantization/enums/ValueQuantizationTypeTests.java +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.quantization.enums; - -import org.opensearch.knn.KNNTestCase; - -public class ValueQuantizationTypeTests extends KNNTestCase { - public void testValueQuantizationTypeValues() { - ValueQuantizationType[] expectedValues = { ValueQuantizationType.SCALAR }; - assertArrayEquals(expectedValues, ValueQuantizationType.values()); - } - - public void testValueQuantizationTypeValueOf() { - assertEquals(ValueQuantizationType.SCALAR, ValueQuantizationType.valueOf("SCALAR")); - } -} diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java index 42fc18eba7..31348e0c6d 100644 --- a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java @@ -52,16 +52,6 @@ public void testGetQuantizer_withFourBitSQParams() { assertTrue(quantizer instanceof MultiBitScalarQuantizer); } - public void testGetQuantizer_withUnsupportedType() { - SQParams params = new SQParams(ScalarQuantizationType.UNSUPPORTED_TYPE); - try { - QuantizerFactory.getQuantizer(params); - fail("Expected IllegalArgumentException"); - } catch (IllegalArgumentException e) { - assertTrue(e.getMessage().contains("No quantizer registered for type identifier")); - } - } - public void testGetQuantizer_withNullParams() { try { QuantizerFactory.getQuantizer(null); diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java index 7f53dae8c2..915bf74bfa 100644 --- a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java @@ -7,7 +7,6 @@ import org.junit.BeforeClass; import org.opensearch.knn.KNNTestCase; -import org.opensearch.knn.quantization.enums.QuantizationType; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.SQParams; import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer; @@ -18,20 +17,9 @@ public class QuantizerRegistryTests extends KNNTestCase { @BeforeClass public static void setup() { - // Register the quantizers for testing with enums - QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE, ScalarQuantizationType.ONE_BIT, OneBitScalarQuantizer::new); - QuantizerRegistry.register( - SQParams.class, - QuantizationType.VALUE, - ScalarQuantizationType.TWO_BIT, - () -> new MultiBitScalarQuantizer(2) - ); - QuantizerRegistry.register( - SQParams.class, - QuantizationType.VALUE, - ScalarQuantizationType.FOUR_BIT, - () -> new MultiBitScalarQuantizer(4) - ); + QuantizerRegistry.register(new SQParams(ScalarQuantizationType.ONE_BIT).getTypeIdentifier(), OneBitScalarQuantizer::new); + QuantizerRegistry.register(new SQParams(ScalarQuantizationType.TWO_BIT).getTypeIdentifier(), () -> new MultiBitScalarQuantizer(2)); + QuantizerRegistry.register(new SQParams(ScalarQuantizationType.FOUR_BIT).getTypeIdentifier(), () -> new MultiBitScalarQuantizer(4)); } public void testRegisterAndGetQuantizer() { @@ -50,17 +38,4 @@ public void testRegisterAndGetQuantizer() { Quantizer fourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); assertTrue(fourBitQuantizer instanceof MultiBitScalarQuantizer); } - - public void testGetQuantizer_withUnsupportedTypeIdentifier() { - // Create SQParams with an unsupported type identifier - SQParams params = new SQParams(ScalarQuantizationType.UNSUPPORTED_TYPE); // Assuming UNSUPPORTED_TYPE is not registered - - // Expect IllegalArgumentException when requesting a quantizer with unsupported params - IllegalArgumentException exception = assertThrows( - IllegalArgumentException.class, - () -> { QuantizerRegistry.getQuantizer(params); } - ); - - assertTrue(exception.getMessage().contains("No quantizer registered for type identifier")); - } } diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java index ebe6bf6bd9..d08670383b 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java @@ -35,7 +35,9 @@ public void testSerializeAndDeserializeMultiBitScalarQuantizationState() throws byte[] serialized = state.toByteArray(); MultiBitScalarQuantizationState deserialized = MultiBitScalarQuantizationState.fromByteArray(serialized); - assertArrayEquals(thresholds, deserialized.getThresholds()); + for (int i = 0; i < thresholds.length; i++) { + assertArrayEquals(thresholds[i], deserialized.getThresholds()[i], 0.0f); + } assertEquals(params, deserialized.getQuantizationParams()); } } diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java index 54e304732b..cbcfaeda35 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java @@ -9,10 +9,8 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.SQParams; -import org.opensearch.knn.quantization.models.quantizationState.DefaultQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; - import java.io.IOException; public class QuantizationStateTests extends KNNTestCase { @@ -23,12 +21,15 @@ public void testOneBitScalarQuantizationStateSerialization() throws IOException, OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + // Serialize byte[] serializedState = state.toByteArray(); + + // Deserialize OneBitScalarQuantizationState deserializedState = OneBitScalarQuantizationState.fromByteArray(serializedState); - float delta = 0.0001f; + float delta = 0.0001f; assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta); - assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType()); + assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType()); } public void testMultiBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException { @@ -37,24 +38,16 @@ public void testMultiBitScalarQuantizationStateSerialization() throws IOExceptio MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + // Serialize byte[] serializedState = state.toByteArray(); + + // Deserialize MultiBitScalarQuantizationState deserializedState = MultiBitScalarQuantizationState.fromByteArray(serializedState); - float delta = 0.0001f; + float delta = 0.0001f; for (int i = 0; i < thresholds.length; i++) { assertArrayEquals(thresholds[i], deserializedState.getThresholds()[i], delta); } - assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType()); - } - - public void testDefaultQuantizationStateSerialization() throws IOException, ClassNotFoundException { - SQParams params = new SQParams(ScalarQuantizationType.UNSUPPORTED_TYPE); - - DefaultQuantizationState state = new DefaultQuantizationState(params); - - byte[] serializedState = state.toByteArray(); - DefaultQuantizationState deserializedState = DefaultQuantizationState.fromByteArray(serializedState); - - assertEquals(params.getQuantizationType(), deserializedState.getQuantizationParams().getQuantizationType()); + assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType()); } } diff --git a/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/BitPackingUtilsTests.java similarity index 94% rename from src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java rename to src/test/java/org/opensearch/knn/quantization/quantizer/BitPackingUtilsTests.java index c91c7177b1..3d115393b2 100644 --- a/src/test/java/org/opensearch/knn/quantization/util/BitPackingUtilsTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/BitPackingUtilsTests.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.knn.quantization.util; +package org.opensearch.knn.quantization.quantizer; import org.opensearch.knn.KNNTestCase; @@ -23,7 +23,7 @@ public void testPackBits() { public void testPackBitsEmptyList() { IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { BitPacker.packBits(Arrays.asList()); }); - assertEquals("The list of bit arrays cannot be empty.", exception.getMessage()); + assertEquals("The list of bit arrays cannot be null or empty.", exception.getMessage()); } public void testPackBitsNullBitArray() { diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java index 231da3dfe7..859e1238a9 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java @@ -7,12 +7,14 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; -import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.SQParams; import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import java.io.IOException; + public class MultiBitScalarQuantizerTests extends KNNTestCase { public void testTrain_twoBit() { @@ -21,10 +23,8 @@ public void testTrain_twoBit() { { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }, { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } }; MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); - int[] sampledIndices = { 0, 1, 2 }; SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT); TrainingRequest request = new MockTrainingRequest(params, vectors); - request.setSampledIndices(sampledIndices); QuantizationState state = twoBitQuantizer.train(request); assertTrue(state instanceof MultiBitScalarQuantizationState); @@ -39,10 +39,8 @@ public void testTrain_fourBit() { { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f }, { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }, { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } }; - int[] sampledIndices = { 0, 1, 2 }; SQParams params = new SQParams(ScalarQuantizationType.FOUR_BIT); TrainingRequest request = new MockTrainingRequest(params, vectors); - request.setSampledIndices(sampledIndices); QuantizationState state = fourBitQuantizer.train(request); assertTrue(state instanceof MultiBitScalarQuantizationState); @@ -51,19 +49,19 @@ public void testTrain_fourBit() { assertEquals(4, mbState.getThresholds().length); // 4-bit quantization should have 4 thresholds } - public void testQuantize_twoBit() { + public void testQuantize_twoBit() throws IOException { MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f }; float[][] thresholds = { { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }, { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } }; SQParams params = new SQParams(ScalarQuantizationType.TWO_BIT); MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); - QuantizationOutput output = twoBitQuantizer.quantize(vector, state); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + twoBitQuantizer.quantize(vector, state, output); assertNotNull(output.getQuantizedVector()); - assertEquals(2, output.getQuantizedVector().length); } - public void testQuantize_fourBit() { + public void testQuantize_fourBit() throws IOException { MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4); float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f }; float[][] thresholds = { @@ -74,35 +72,30 @@ public void testQuantize_fourBit() { SQParams params = new SQParams(ScalarQuantizationType.FOUR_BIT); MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); - QuantizationOutput output = fourBitQuantizer.quantize(vector, state); - assertEquals(4, output.getQuantizedVector().length); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + fourBitQuantizer.quantize(vector, state, output); assertNotNull(output.getQuantizedVector()); } - public void testQuantize_withNullVector() { + public void testQuantize_withNullVector() throws IOException { MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); expectThrows( IllegalArgumentException.class, () -> twoBitQuantizer.quantize( null, - new MultiBitScalarQuantizationState(new SQParams(ScalarQuantizationType.TWO_BIT), new float[2][8]) + new MultiBitScalarQuantizationState(new SQParams(ScalarQuantizationType.TWO_BIT), new float[2][8]), + output ) ); } - public void testQuantize_withInvalidState() { - MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2); - float[] vector = { 1.3f, 2.2f, 3.3f, 4.1f, 5.6f, 6.7f, 7.4f, 8.1f }; - QuantizationState invalidState = new MockInvalidQuantizationState(); - expectThrows(IllegalArgumentException.class, () -> twoBitQuantizer.quantize(vector, invalidState)); - } - // Mock classes for testing private static class MockTrainingRequest extends TrainingRequest { private final float[][] vectors; public MockTrainingRequest(SQParams params, float[][] vectors) { - super(params, vectors.length); + super(vectors.length); this.vectors = vectors; } @@ -111,16 +104,4 @@ public float[] getVectorByDocId(int docId) { return vectors[docId]; } } - - private static class MockInvalidQuantizationState implements QuantizationState { - @Override - public SQParams getQuantizationParams() { - return new SQParams(ScalarQuantizationType.UNSUPPORTED_TYPE); - } - - @Override - public byte[] toByteArray() { - return new byte[0]; - } - } } diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java index 12e8a43a82..bd4d60e0dd 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java @@ -7,14 +7,19 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; -import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.SQParams; import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.requests.TrainingRequest; import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplerType; import org.opensearch.knn.quantization.sampler.SamplingFactory; -import org.opensearch.knn.quantization.util.QuantizerHelper; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; +import java.util.BitSet; public class OneBitScalarQuantizerTests extends KNNTestCase { @@ -22,7 +27,7 @@ public void testTrain_withTrainingRequired() { float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } }; SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT); - TrainingRequest originalRequest = new TrainingRequest(params, vectors.length) { + TrainingRequest originalRequest = new TrainingRequest(vectors.length) { @Override public float[] getVectorByDocId(int docId) { return vectors[docId]; @@ -36,29 +41,31 @@ public float[] getVectorByDocId(int docId) { assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, mean, 0.001f); } - public void testQuantize_withState() { + public void testQuantize_withState() throws IOException { float[] vector = { 3.0f, 6.0f, 9.0f }; float[] thresholds = { 4.0f, 5.0f, 6.0f }; OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(ScalarQuantizationType.ONE_BIT), thresholds); OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); - QuantizationOutput output = quantizer.quantize(vector, state); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + quantizer.quantize(vector, state, output); assertNotNull(output); byte[] expectedPackedBits = new byte[] { 0b01100000 }; // or 96 in decimal assertArrayEquals(expectedPackedBits, output.getQuantizedVector()); } - public void testQuantize_withNullVector() { + public void testQuantize_withNullVector() throws IOException { OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); OneBitScalarQuantizationState state = new OneBitScalarQuantizationState( new SQParams(ScalarQuantizationType.ONE_BIT), new float[] { 0.0f } ); - expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(null, state)); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(null, state, output)); } - public void testQuantize_withInvalidState() { + public void testQuantize_withInvalidState() throws IOException { OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); float[] vector = { 1.0f, 2.0f, 3.0f }; QuantizationState invalidState = new QuantizationState() { @@ -71,32 +78,43 @@ public SQParams getQuantizationParams() { public byte[] toByteArray() { return new byte[0]; } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + // no-op + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + // no-op + } }; - expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, invalidState)); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, invalidState, output)); } - public void testQuantize_withMismatchedDimensions() { + public void testQuantize_withMismatchedDimensions() throws IOException { OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); float[] vector = { 1.0f, 2.0f, 3.0f }; float[] thresholds = { 4.0f, 5.0f }; OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(new SQParams(ScalarQuantizationType.ONE_BIT), thresholds); - - expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, state)); + BinaryQuantizationOutput output = new BinaryQuantizationOutput(); + expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, state, output)); } public void testCalculateMean() { float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } }; SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT); - TrainingRequest samplingRequest = new TrainingRequest(params, vectors.length) { + TrainingRequest samplingRequest = new TrainingRequest(vectors.length) { @Override public float[] getVectorByDocId(int docId) { return vectors[docId]; } }; - Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); - int[] sampledIndices = sampler.sample(vectors.length, 3); + Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR); + BitSet sampledIndices = sampler.sample(vectors.length, 3); float[] mean = QuantizerHelper.calculateMean(samplingRequest, sampledIndices); assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, mean, 0.001f); } @@ -105,15 +123,15 @@ public void testCalculateMean_withNullVector() { float[][] vectors = { { 1.0f, 2.0f, 3.0f }, null, { 7.0f, 8.0f, 9.0f } }; SQParams params = new SQParams(ScalarQuantizationType.ONE_BIT); - TrainingRequest samplingRequest = new TrainingRequest(params, vectors.length) { + TrainingRequest samplingRequest = new TrainingRequest(vectors.length) { @Override public float[] getVectorByDocId(int docId) { return vectors[docId]; } }; - Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); - int[] sampledIndices = sampler.sample(vectors.length, 3); + Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR); + BitSet sampledIndices = sampler.sample(vectors.length, 3); expectThrows(IllegalArgumentException.class, () -> QuantizerHelper.calculateMean(samplingRequest, sampledIndices)); } } diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java index 4d33452890..e317bfb4b3 100644 --- a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java @@ -7,8 +7,7 @@ import org.opensearch.knn.KNNTestCase; -import java.util.Arrays; -import java.util.stream.IntStream; +import java.util.BitSet; public class ReservoirSamplerTests extends KNNTestCase { @@ -16,40 +15,20 @@ public void testSampleLessThanSampleSize() { ReservoirSampler sampler = new ReservoirSampler(); int totalNumberOfVectors = 5; int sampleSize = 10; - int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray(); - assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); + BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + BitSet expectedIndices = new BitSet(totalNumberOfVectors); + expectedIndices.set(0, totalNumberOfVectors); + assertEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); } public void testSampleEqualToSampleSize() { ReservoirSampler sampler = new ReservoirSampler(); int totalNumberOfVectors = 10; int sampleSize = 10; - int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray(); - assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); - } - - public void testSampleGreaterThanSampleSize() { - ReservoirSampler sampler = new ReservoirSampler(12345); // Fixed seed for reproducibility - int totalNumberOfVectors = 100; - int sampleSize = 10; - int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - assertEquals(sampleSize, sampledIndices.length); - assertTrue(Arrays.stream(sampledIndices).allMatch(i -> i >= 0 && i < totalNumberOfVectors)); - } - - public void testSampleReproducibility() { - long seed = 12345L; - ReservoirSampler sampler1 = new ReservoirSampler(seed); - ReservoirSampler sampler2 = new ReservoirSampler(seed); - int totalNumberOfVectors = 100; - int sampleSize = 10; - - int[] sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize); - int[] sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize); - - assertArrayEquals(sampledIndices1, sampledIndices2); + BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + BitSet expectedIndices = new BitSet(totalNumberOfVectors); + expectedIndices.set(0, totalNumberOfVectors); + assertEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); } public void testSampleRandomness() { @@ -58,25 +37,25 @@ public void testSampleRandomness() { int totalNumberOfVectors = 100; int sampleSize = 10; - int[] sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize); - int[] sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize); + BitSet sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize); + BitSet sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize); - assertNotEquals(Arrays.toString(sampledIndices1), Arrays.toString(sampledIndices2)); + assertNotEquals(sampledIndices1, sampledIndices2); } public void testEdgeCaseZeroVectors() { ReservoirSampler sampler = new ReservoirSampler(); int totalNumberOfVectors = 0; int sampleSize = 10; - int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - assertEquals(0, sampledIndices.length); + BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals(0, sampledIndices.cardinality()); } public void testEdgeCaseZeroSampleSize() { ReservoirSampler sampler = new ReservoirSampler(); int totalNumberOfVectors = 10; int sampleSize = 0; - int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - assertEquals(0, sampledIndices.length); + BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals(0, sampledIndices.cardinality()); } } diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java index ca72c1c5e5..db8772b706 100644 --- a/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java +++ b/src/test/java/org/opensearch/knn/quantization/sampler/SamplingFactoryTests.java @@ -9,7 +9,7 @@ public class SamplingFactoryTests extends KNNTestCase { public void testGetSampler_withReservoir() { - Sampler sampler = SamplingFactory.getSampler(SamplingFactory.SamplerType.RESERVOIR); + Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR); assertTrue(sampler instanceof ReservoirSampler); }