diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java index dbf8e5bf91..df7d84d7d3 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationOutput/BinaryQuantizationOutput.java @@ -5,11 +5,8 @@ package org.opensearch.knn.quantization.models.quantizationOutput; -import lombok.NoArgsConstructor; import lombok.Getter; - -import java.io.ByteArrayOutputStream; -import java.io.IOException; +import lombok.NoArgsConstructor; /** * The BinaryQuantizationOutput class represents the output of a quantization process in binary format. @@ -18,23 +15,27 @@ @NoArgsConstructor public class BinaryQuantizationOutput implements QuantizationOutput { @Getter - private final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + private byte[] quantizedVector; /** * Updates the quantized vector with a new byte array. * * @param newQuantizedVector the new quantized vector represented as a byte array. */ - public void updateQuantizedVector(final byte[] newQuantizedVector) throws IOException { + public void updateQuantizedVector(final byte[] newQuantizedVector) { if (newQuantizedVector == null || newQuantizedVector.length == 0) { throw new IllegalArgumentException("Quantized vector cannot be null or empty"); } - byteArrayOutputStream.reset(); - byteArrayOutputStream.write(newQuantizedVector); + this.quantizedVector = newQuantizedVector; } + /** + * Returns the quantized vector. + * + * @return the quantized vector byte array. 
+ */ @Override public byte[] getQuantizedVector() { - return byteArrayOutputStream.toByteArray(); + return quantizedVector; } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java index 88b22c4fda..4f2ee36c5b 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/QuantizationParams.java @@ -5,7 +5,7 @@ package org.opensearch.knn.quantization.models.quantizationParams; -import java.io.Externalizable; +import org.opensearch.core.common.io.stream.Writeable; /** * Interface for quantization parameters. @@ -14,7 +14,7 @@ * Implementations of this interface are expected to provide specific configurations * for various quantization strategies. */ -public interface QuantizationParams extends Externalizable { +public interface QuantizationParams extends Writeable { /** * Provides a unique identifier for the quantization parameters. 
* This identifier is typically a combination of the quantization type diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java index c7c24062dd..5ca0840bbd 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationParams/ScalarQuantizationParams.java @@ -9,15 +9,15 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.NoArgsConstructor; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; import java.util.Locale; /** - * The SQParams class represents the parameters specific to scalar quantization (SQ). + * The ScalarQuantizationParams class represents the parameters specific to scalar quantization (SQ). * This class implements the QuantizationParams interface and includes the type of scalar quantization. */ @Getter @@ -39,67 +39,40 @@ public static String generateTypeIdentifier(ScalarQuantizationType sqType) { } /** - * Serializes the SQParams object to an external output. - * This method writes the scalar quantization type to the output stream. + * Provides a unique type identifier for the ScalarQuantizationParams, combining the SQ type. + * This identifier is useful for distinguishing between different configurations of scalar quantization parameters. * - * @param out the ObjectOutput to write the object to. - * @throws IOException if an I/O error occurs during serialization. + * @return A string representing the unique type identifier. 
*/ @Override - public void writeExternal(ObjectOutput out) throws IOException { - // The version is already written by the parent state class, no need to write it here again - // Retrieve the current version from VersionContext - // This context will be used by other classes involved in the serialization process. - // Example: - // int version = VersionContext.getVersion(); // Get the current version from VersionContext - // Any Version Specific logic can be wriiten based on Version - out.writeObject(sqType); + public String getTypeIdentifier() { + return generateIdentifier(sqType.getId()); + } + + private static String generateIdentifier(int id) { + return String.format(Locale.ROOT, "ScalarQuantizationParams_%d", id); } /** - * Deserializes the SQParams object from an external input with versioning. - * This method reads the scalar quantization type and new field from the input stream based on the version. + * Writes the object to the output stream. + * This method is part of the Writeable interface and is used to serialize the object. * - * @param in the ObjectInput to read the object from. - * @throws IOException if an I/O error occurs during deserialization. - * @throws ClassNotFoundException if the class of the serialized object cannot be found. + * @param out the output stream to write the object to. + * @throws IOException if an I/O error occurs. 
*/ @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - // The version is already read by the parent state class and set in VersionContext - // Retrieve the current version from VersionContext to handle version-specific deserialization logic - // int versionId = VersionContext.getVersion(); - // Version version = Version.fromId(versionId); - - sqType = (ScalarQuantizationType) in.readObject(); - - // Add version-specific deserialization logic - // For example, if new fields are added in a future version, handle them here - // This section contains conditional logic to handle different versions appropriately. - // Example: - // if (version.onOrAfter(Version.V_1_0_0) && version.before(Version.V_2_0_0)) { - // // Handle logic for versions between 1.0.0 and 2.0.0 - // // Example: Read additional fields introduced in version 1.0.0 - // // newField = in.readInt(); - // } else if (version.onOrAfter(Version.V_2_0_0)) { - // // Handle logic for versions 2.0.0 and above - // // Example: Read additional fields introduced in version 2.0.0 - // // anotherNewField = in.readFloat(); - // } + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(sqType); } /** - * Provides a unique type identifier for the SQParams, combining the SQ type. - * This identifier is useful for distinguishing between different configurations of scalar quantization parameters. + * Reads the object from the input stream. + * This method is part of the Writeable interface and is used to deserialize the object. * - * @return A string representing the unique type identifier. + * @param in the input stream to read the object from. + * @throws IOException if an I/O error occurs. 
*/ - @Override - public String getTypeIdentifier() { - return generateIdentifier(sqType.getId()); - } - - private static String generateIdentifier(int id) { - return String.format(Locale.ROOT, "ScalarQuantizationParams_%d", id); + public ScalarQuantizationParams(StreamInput in, int version) throws IOException { + this.sqType = in.readEnum(ScalarQuantizationType.class); } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java index 3e3249c6ff..24bba902b8 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java @@ -9,12 +9,12 @@ import lombok.Getter; import lombok.NoArgsConstructor; import org.opensearch.Version; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; /** * DefaultQuantizationState is used as a fallback state when no training is required or if training fails. @@ -27,16 +27,22 @@ public class DefaultQuantizationState implements QuantizationState { private QuantizationParams params; private static final long serialVersionUID = 1L; // Version ID for serialization - /** - * Returns the quantization parameters associated with this state. - * - * @return the quantization parameters. 
- */ @Override public QuantizationParams getQuantizationParams() { return params; } + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(Version.CURRENT.id); // Write the version + params.writeTo(out); + } + + public DefaultQuantizationState(StreamInput in) throws IOException { + int version = in.readInt(); // Read the version + this.params = new ScalarQuantizationParams(in, version); + } + /** * Serializes the quantization state to a byte array. * @@ -45,7 +51,7 @@ public QuantizationParams getQuantizationParams() { */ @Override public byte[] toByteArray() throws IOException { - return QuantizationStateSerializer.serialize(this, null); + return QuantizationStateSerializer.serialize(this); } /** @@ -57,36 +63,6 @@ public byte[] toByteArray() throws IOException { * @throws ClassNotFoundException if the class of the serialized object cannot be found. */ public static DefaultQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException { - return (DefaultQuantizationState) QuantizationStateSerializer.deserialize( - bytes, - new DefaultQuantizationState(), - (parentParams, specificData) -> new DefaultQuantizationState((ScalarQuantizationParams) parentParams) - ); - } - - /** - * Writes the object to the output stream. - * This method is part of the Externalizable interface and is used to serialize the object. - * - * @param out the output stream to write the object to. - * @throws IOException if an I/O error occurs. - */ - @Override - public void writeExternal(ObjectOutput out) throws IOException { - out.writeInt(Version.CURRENT.id); // Write the version - out.writeObject(params); - } - - /** - * Reads the object from the input stream. - * This method is part of the Externalizable interface and is used to deserialize the object. - * - * @param in the input stream to read the object from. - * @throws IOException if an I/O error occurs. 
- * @throws ClassNotFoundException if the class of the serialized object cannot be found. - */ - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - this.params = (QuantizationParams) in.readObject(); + return (DefaultQuantizationState) QuantizationStateSerializer.deserialize(bytes, DefaultQuantizationState::new); } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java index 095d245f23..ed868efb50 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java @@ -9,12 +9,11 @@ import lombok.Getter; import lombok.NoArgsConstructor; import org.opensearch.Version; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; -import org.opensearch.knn.quantization.util.VersionContext; import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; /** * MultiBitScalarQuantizationState represents the state of multi-bit scalar quantization, @@ -50,37 +49,31 @@ public ScalarQuantizationParams getQuantizationParams() { } /** - * This method is responsible for writing the state of the OneBitScalarQuantizationState object to an external output. + * This method is responsible for writing the state of the MultiBitScalarQuantizationState object to an external output. * It includes versioning information to ensure compatibility between different versions of the serialized object. * - *

Versioning is managed using the {@link VersionContext} class. This allows other classes that are serialized - * as part of the state to access the version information and implement version-specific logic if needed.

- * - *

The {@link VersionContext#setVersion(int)} method sets the version information in a thread-local variable, - * ensuring that the version is available to all classes involved in the serialization process within the current thread context.

- * *
      * {@code
      * // Example usage in the writeTo method:
-     * VersionContext.setVersion(version);
      * out.writeInt(version); // Write the version
-     * quantizationParams.writeExternal(out);
-     * out.writeInt(meanThresholds.length);
-     * for (float mean : meanThresholds) {
-     *     out.writeFloat(mean);
+     * quantizationParams.writeTo(out);
+     * out.writeInt(thresholds.length);
+     * out.writeInt(thresholds[0].length);
+     * for (float[] row : thresholds) {
+     *     for (float value : row) {
+     *         out.writeFloat(value);
+     *     }
      * }
      * }
      * 
* - * @param out the ObjectOutput to write the object to. + * @param out the StreamOutput to write the object to. * @throws IOException if an I/O error occurs during serialization. */ @Override - public void writeExternal(ObjectOutput out) throws IOException { - int version = Version.CURRENT.id; - VersionContext.setVersion(version); - out.writeInt(version); // Write the version - quantizationParams.writeExternal(out); + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(Version.CURRENT.id); // Write the version + quantizationParams.writeTo(out); out.writeInt(thresholds.length); out.writeInt(thresholds[0].length); for (float[] row : thresholds) { @@ -91,40 +84,32 @@ public void writeExternal(ObjectOutput out) throws IOException { } /** - * This method is responsible for reading the state of the OneBitScalarQuantizationState object from an external input. + * This method is responsible for reading the state of the MultiBitScalarQuantizationState object from an external input. * It includes versioning information to ensure compatibility between different versions of the serialized object. * - *

The version information is read first, and then it is set using the {@link VersionContext#setVersion(int)} method. - * This makes the version information available to all classes involved in the deserialization process within the current thread context.

- * - *

Classes that are part of the deserialization process can retrieve the version information using the - * {@link VersionContext#getVersion()} method and implement version-specific logic accordingly.

- * *
      * {@code
      * // Example usage in the StreamInput constructor:
      * int version = in.readInt(); // Read the version
-     * VersionContext.setVersion(version);
-     * quantizationParams = new ScalarQuantizationParams();
-     * quantizationParams.readExternal(in); // Use readExternal of SQParams
-     * int length = in.readInt();
-     * meanThresholds = new float[length];
-     * for (int i = 0; i < length; i++) {
-     *     meanThresholds[i] = in.readFloat();
+     * quantizationParams = new ScalarQuantizationParams(in, version);
+     * int rows = in.readInt();
+     * int cols = in.readInt();
+     * thresholds = new float[rows][cols];
+     * for (int i = 0; i < rows; i++) {
+     *     for (int j = 0; j < cols; j++) {
+     *         thresholds[i][j] = in.readFloat();
+     *     }
      * }
      * }
      * 
* - * @param in the ObjectInput to read the object from. + * @param in the StreamInput to read the object from. * @throws IOException if an I/O error occurs during deserialization. * @throws ClassNotFoundException if the class of the serialized object cannot be found. */ - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + public MultiBitScalarQuantizationState(StreamInput in) throws IOException { int version = in.readInt(); // Read the version - VersionContext.setVersion(version); - quantizationParams = new ScalarQuantizationParams(); - quantizationParams.readExternal(in); // Use readExternal of SQParams + this.quantizationParams = new ScalarQuantizationParams(in, version); int rows = in.readInt(); int cols = in.readInt(); thresholds = new float[rows][cols]; @@ -133,7 +118,6 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept thresholds[i][j] = in.readFloat(); } } - VersionContext.clear(); // Clear the version after use } /** @@ -155,7 +139,7 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept */ @Override public byte[] toByteArray() throws IOException { - return QuantizationStateSerializer.serialize(this, thresholds); + return QuantizationStateSerializer.serialize(this); } /** @@ -175,16 +159,8 @@ public byte[] toByteArray() throws IOException { * @param bytes the byte array containing the serialized state. * @return the deserialized MultiBitScalarQuantizationState object. * @throws IOException if an I/O error occurs during deserialization. - * @throws ClassNotFoundException if the class of a serialized object cannot be found. 
*/ - public static MultiBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException { - return (MultiBitScalarQuantizationState) QuantizationStateSerializer.deserialize( - bytes, - new MultiBitScalarQuantizationState(), - (parentParams, thresholds) -> new MultiBitScalarQuantizationState( - (ScalarQuantizationParams) parentParams, - (float[][]) thresholds - ) - ); + public static MultiBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException { + return (MultiBitScalarQuantizationState) QuantizationStateSerializer.deserialize(bytes, MultiBitScalarQuantizationState::new); } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java index 8ab37955e5..ff1aeedc16 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java @@ -9,12 +9,11 @@ import lombok.Getter; import lombok.NoArgsConstructor; import org.opensearch.Version; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; -import org.opensearch.knn.quantization.util.VersionContext; import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; /** * OneBitScalarQuantizationState represents the state of one-bit scalar quantization, @@ -49,18 +48,12 @@ public ScalarQuantizationParams getQuantizationParams() { * This method is responsible for writing the state of the OneBitScalarQuantizationState object to an external output. * It includes versioning information to ensure compatibility between different versions of the serialized object. * - *

Versioning is managed using the {@link VersionContext} class. This allows other classes that are serialized - * as part of the state to access the version information and implement version-specific logic if needed.

- * - *

The {@link VersionContext#setVersion(int)} method sets the version information in a thread-local variable, - * ensuring that the version is available to all classes involved in the serialization process within the current thread context.

* *
      * {@code
      * // Example usage in the writeTo method:
-     * VersionContext.setVersion(version);
      * out.writeInt(version); // Write the version
-     * quantizationParams.writeExternal(out);
+     * quantizationParams.writeTo(out);
      * out.writeInt(meanThresholds.length);
      * for (float mean : meanThresholds) {
      *     out.writeFloat(mean);
@@ -68,39 +61,29 @@ public ScalarQuantizationParams getQuantizationParams() {
      * }
      * 
* - * @param out the ObjectOutput to write the object to. + * @param out the StreamOutput to write the object to. * @throws IOException if an I/O error occurs during serialization. */ @Override - public void writeExternal(ObjectOutput out) throws IOException { - int version = Version.CURRENT.id; - VersionContext.setVersion(version); - out.writeInt(version); // Write the version - quantizationParams.writeExternal(out); + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(Version.CURRENT.id); // Write the version + quantizationParams.writeTo(out); out.writeInt(meanThresholds.length); for (float mean : meanThresholds) { out.writeFloat(mean); } - VersionContext.clear(); // Clear the version after use } /** * This method is responsible for reading the state of the OneBitScalarQuantizationState object from an external input. * It includes versioning information to ensure compatibility between different versions of the serialized object. * - *

The version information is read first, and then it is set using the {@link VersionContext#setVersion(int)} method. - * This makes the version information available to all classes involved in the deserialization process within the current thread context.

- * - *

Classes that are part of the deserialization process can retrieve the version information using the - * {@link VersionContext#getVersion()} method and implement version-specific logic accordingly.

* *
      * {@code
      * // Example usage in the StreamInput constructor:
      * int version = in.readInt(); // Read the version
-     * VersionContext.setVersion(version);
-     * quantizationParams = new ScalarQuantizationParams();
-     * quantizationParams.readExternal(in); // Use readExternal of SQParams
+     * quantizationParams = new ScalarQuantizationParams(in, version);
      * int length = in.readInt();
      * meanThresholds = new float[length];
      * for (int i = 0; i < length; i++) {
@@ -109,22 +92,18 @@ public void writeExternal(ObjectOutput out) throws IOException {
      * }
      * 
* - * @param in the ObjectInput to read the object from. + * @param in the StreamInput to read the object from. * @throws IOException if an I/O error occurs during deserialization. * @throws ClassNotFoundException if the class of the serialized object cannot be found. */ - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + public OneBitScalarQuantizationState(StreamInput in) throws IOException { int version = in.readInt(); // Read the version - VersionContext.setVersion(version); - quantizationParams = new ScalarQuantizationParams(); - quantizationParams.readExternal(in); // Use readExternal of SQParams + this.quantizationParams = new ScalarQuantizationParams(in, version); int length = in.readInt(); meanThresholds = new float[length]; for (int i = 0; i < length; i++) { meanThresholds[i] = in.readFloat(); } - VersionContext.clear(); // Clear the version after use } /** @@ -146,7 +125,7 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept */ @Override public byte[] toByteArray() throws IOException { - return QuantizationStateSerializer.serialize(this, meanThresholds); + return QuantizationStateSerializer.serialize(this); } /** @@ -166,16 +145,8 @@ public byte[] toByteArray() throws IOException { * @param bytes the byte array containing the serialized state. * @return the deserialized OneBitScalarQuantizationState object. * @throws IOException if an I/O error occurs during deserialization. - * @throws ClassNotFoundException if the class of a serialized object cannot be found. 
*/ - public static OneBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException { - return (OneBitScalarQuantizationState) QuantizationStateSerializer.deserialize( - bytes, - new OneBitScalarQuantizationState(), - (parentParams, meanThresholds) -> new OneBitScalarQuantizationState( - (ScalarQuantizationParams) parentParams, - (float[]) meanThresholds - ) - ); + public static OneBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException { + return (OneBitScalarQuantizationState) QuantizationStateSerializer.deserialize(bytes, OneBitScalarQuantizationState::new); } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java index d3778fe299..e32df8bc36 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java @@ -5,16 +5,16 @@ package org.opensearch.knn.quantization.models.quantizationState; +import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; -import java.io.Externalizable; import java.io.IOException; /** * QuantizationState interface represents the state of a quantization process, including the parameters used. * This interface provides methods for serializing and deserializing the state. */ -public interface QuantizationState extends Externalizable { +public interface QuantizationState extends Writeable { /** * Returns the quantization parameters associated with this state. 
* diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java index 5f00d8e0ca..1f378e0dc4 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationStateSerializer.java @@ -6,14 +6,10 @@ package org.opensearch.knn.quantization.models.quantizationState; import lombok.experimental.UtilityClass; -import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; -import java.io.ByteArrayOutputStream; -import java.io.ObjectOutputStream; -import java.io.Serializable; import java.io.IOException; -import java.io.ByteArrayInputStream; -import java.io.ObjectInputStream; /** * QuantizationStateSerializer is a utility class that provides methods for serializing and deserializing @@ -27,23 +23,20 @@ class QuantizationStateSerializer { */ @FunctionalInterface interface SerializableDeserializer { - QuantizationState deserialize(QuantizationParams parentParams, Serializable specificData); + QuantizationState deserialize(StreamInput in) throws IOException; } /** * Serializes the QuantizationState and specific data into a byte array. * * @param state The QuantizationState to serialize. - * @param specificData The specific data related to the state, to be serialized. * @return A byte array representing the serialized state and specific data. * @throws IOException If an I/O error occurs during serialization. 
*/ - static byte[] serialize(QuantizationState state, Serializable specificData) throws IOException { - try (ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutputStream out = new ObjectOutputStream(bos)) { - state.writeExternal(out); - out.writeObject(specificData); - out.flush(); - return bos.toByteArray(); + static byte[] serialize(QuantizationState state) throws IOException { + try (BytesStreamOutput out = new BytesStreamOutput()) { + state.writeTo(out); + return out.bytes().toBytesRef().bytes; } } @@ -51,18 +44,13 @@ static byte[] serialize(QuantizationState state, Serializable specificData) thro * Deserializes a QuantizationState and its specific data from a byte array. * * @param bytes The byte array containing the serialized data. - * @param stateInstance An instance of the state to call readExternal on. - * @param specificDataDeserializer The deserializer for the specific data associated with the state. + * @param deserializer The deserializer for the specific data associated with the state. * @return The deserialized QuantizationState including its specific data. * @throws IOException If an I/O error occurs during deserialization. - * @throws ClassNotFoundException If the class of the serialized object cannot be found. 
*/ - static QuantizationState deserialize(byte[] bytes, QuantizationState stateInstance, SerializableDeserializer specificDataDeserializer) - throws IOException, ClassNotFoundException { - try (ByteArrayInputStream bis = new ByteArrayInputStream(bytes); ObjectInputStream in = new ObjectInputStream(bis)) { - stateInstance.readExternal(in); - Serializable specificData = (Serializable) in.readObject(); // Read the specific data - return specificDataDeserializer.deserialize(stateInstance.getQuantizationParams(), specificData); + static QuantizationState deserialize(byte[] bytes, SerializableDeserializer deserializer) throws IOException { + try (StreamInput in = StreamInput.wrap(bytes)) { + return deserializer.deserialize(in); } } } diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java b/src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java new file mode 100644 index 0000000000..b060b6e20e --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/BitPacker.java @@ -0,0 +1,141 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import lombok.experimental.UtilityClass; + +/** + * The BitPackingUtil class provides utility methods for packing bits into a byte array. + * This class is designed to be used by quantizers that need to convert floating-point vectors + * into compact binary representations by comparing them against quantization thresholds. + * + *

+ * The methods in this class handle both single-bit and multi-bit quantization scenarios, + * allowing for efficient storage and transmission of quantized vectors. + *

+ * + *

+ * This class is marked as a utility class using Lombok's {@link lombok.experimental.UtilityClass} annotation, + * which makes the class final, makes all of its members static, and prevents instantiation. + *

+ */ +@UtilityClass +class BitPacker { + + /** + * Quantizes a given floating-point vector and packs the resulting quantized bits into a byte array. + * This method operates by comparing each element of the input vector against corresponding thresholds + * and encoding the results into a compact binary format using the specified number of bits per coordinate. + * + *

+ * The method supports multi-bit quantization where each coordinate of the input vector can be represented + * by multiple bits. For example, with 2-bit quantization, each coordinate is encoded into 2 bits, allowing + * for four distinct levels of quantization per coordinate. + *

+ * + *

+ * Example: + *

+ *

+ * Consider a vector with 3 coordinates: [1.2, 3.4, 5.6] and thresholds: + *

+ *
+     * thresholds = {
+     *     {1.0, 3.0, 5.0},  // First bit thresholds
+     *     {1.5, 3.5, 5.5}   // Second bit thresholds
+     * };
+     * 
+ *

+ * If the number of bits per coordinate is 2, the quantization process will proceed as follows: + *

+ * + *

+ * Each threshold row yields one bit per coordinate: the first row gives 1 1 1 and the second row gives 0 0 1. + * These six bits are packed into a byte array with a size calculated to fit all bits. If the bits do not fill + * the last byte, the remaining bits in that byte are set to 0. + *

+ * + *

+ * Packing Process: + * The quantized bits are packed into the byte array one threshold row at a time. The first row's bits for all + * coordinates are stored in the most significant positions of the first byte, followed by the second row's bits, and so on. In the example + * above, the resulting byte array will have the following binary representation: + *

+ *
+     * packedBits = [11100100] // Only the first 6 bits are used, and the last two are set to 0.
+     * 
+ * + * @param vector the floating-point vector to be quantized. + * @param thresholds a 2D array representing the quantization thresholds. The first dimension corresponds to the number of bits per coordinate, and the second dimension corresponds to the vector's length. + * @param bitsPerCoordinate the number of bits used per coordinate, determining the granularity of the quantization. + * @return a byte array containing the packed bits representing the quantized vector. + */ + byte[] quantizeAndPackBits(final float[] vector, final float[][] thresholds, final int bitsPerCoordinate) { + int vectorLength = vector.length; + int totalBits = bitsPerCoordinate * vectorLength; + // Calculate the number of bytes needed to store the totalBits. + int byteLength = (totalBits + 7) >> 3; // Equivalent to (totalBits + 7) / 8 + byte[] packedBits = new byte[byteLength]; + + for (int i = 0; i < bitsPerCoordinate; i++) { + for (int j = 0; j < vectorLength; j++) { + if (vector[j] > thresholds[i][j]) { + int bitPosition = i * vectorLength + j; + // Calculate the index of the byte in the packedBits array. + // Pseudo-code formula: byteIndex = floor(bitPosition / 8) + int byteIndex = bitPosition >> 3; // Equivalent to bitPosition / 8 + // Calculate the bit index within the byte. + // Pseudo-code formula: bitIndex = 7 - (bitPosition % 8) + int bitIndex = 7 - (bitPosition & 7); // Equivalent to 7 - (bitPosition % 8) + // Set the bit at the calculated position. + // Pseudo-code formula: packedBits[byteIndex] = packedBits[byteIndex] | (1 << bitIndex) + packedBits[byteIndex] |= (1 << bitIndex); // Set the bit at bitIndex + } + } + } + + return packedBits; + } + + + /** + * Overloaded method to pack bits for one-bit quantization. + * + * @param vector The vector to quantize. + * @param thresholds The thresholds for quantization. + * + * Example: + *
+     * If we have a vector [1.2, 3.4, 5.6] and per-dimension mean thresholds [2.0, 3.0, 4.0],
+     * the quantization process will be:
+     * - 1.2 < 2.0, so the first bit is 0
+     * - 3.4 > 3.0, so the second bit is 1
+     * - 5.6 > 4.0, so the third bit is 1
+     *
+     * The quantized vector will be [0, 1, 1].
+     * 
+ * + * @return The packed bits as a byte array. + */ + byte[] quantizeAndPackBits(final float[] vector, final float[] thresholds) { + return quantizeAndPackBits(vector, new float[][] { thresholds }, 1); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java index c6d366e5b7..8035b5b0f8 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java @@ -17,11 +17,50 @@ import org.opensearch.knn.quantization.sampler.SamplingFactory; import java.io.IOException; -import java.util.BitSet; /** * MultiBitScalarQuantizer is responsible for quantizing vectors into multi-bit representations per dimension. - * It supports multiple bits per coordinate, allowing for finer granularity in quantization. + * Unlike the OneBitScalarQuantizer, which uses a single bit per dimension to represent whether a value is above + * or below a mean threshold, the MultiBitScalarQuantizer allows for multiple bits per dimension, enabling more + * granular and precise quantization. + * + *

+ * In a OneBitScalarQuantizer, each dimension of a vector is compared to a single threshold (the mean), and a single + * bit is used to indicate whether the value is above or below that threshold. This results in a very coarse + * representation where each dimension is either "on" or "off." + *

+ * + *

+ * The MultiBitScalarQuantizer, on the other hand, uses multiple thresholds per dimension. For example, in a 2-bit + * quantization scheme, two thresholds are used for each dimension, one per bit, and each bit records whether the + * value exceeds its threshold. This allows for a much finer representation of the data, capturing more + * nuances in the variation of each dimension. + *

+ * + *

+ * The thresholds in MultiBitScalarQuantizer are calculated based on the mean and standard deviation of the sampled + * vectors for each dimension. Here's how it works: + *

+ * + * + * + *

+ * The number of bits per coordinate is determined by the type of scalar quantization being applied, such as 2-bit + * or 4-bit quantization. The increased number of bits per coordinate in MultiBitScalarQuantizer allows for better + * preservation of information during the quantization process, making it more suitable for tasks where precision + * is crucial. However, this comes at the cost of increased storage and computational complexity compared to the + * simpler OneBitScalarQuantizer. + *

*/ public class MultiBitScalarQuantizer implements Quantizer { private final int bitsPerCoordinate; // Number of bits used to quantize each dimension @@ -69,19 +108,20 @@ public MultiBitScalarQuantizer(final int bitsPerCoordinate, final int samplingSi */ @Override public QuantizationState train(final TrainingRequest trainingRequest) { - BitSet sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); - int dimension = trainingRequest.getVectorByDocId(sampledIndices.nextSetBit(0)).length; + int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); + int dimension = trainingRequest.getVectorByDocId(sampledIndices[0]).length; float[] meanArray = new float[dimension]; float[] stdDevArray = new float[dimension]; // Calculate sum, mean, and standard deviation in one pass - QuantizerHelper.calculateSumMeanAndStdDev(trainingRequest, sampledIndices, meanArray, stdDevArray); + QuantizerHelper.calculateMeanAndStdDev(trainingRequest, sampledIndices, meanArray, stdDevArray); float[][] thresholds = calculateThresholds(meanArray, stdDevArray, dimension); ScalarQuantizationParams params = (bitsPerCoordinate == 2) - ? new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT) - : new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); + ? new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT) + : new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); return new MultiBitScalarQuantizationState(params, thresholds); } + /** * Quantizes the provided vector using the provided quantization state, producing a quantized output. * The vector is quantized based on the thresholds in the quantization state. 
@@ -102,21 +142,7 @@ public void quantize(final float[] vector, final QuantizationState state, final if (thresholds == null || thresholds[0].length != vector.length) { throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); } - // Directly pack bits without intermediate array - int totalBits = bitsPerCoordinate * vector.length; - int byteLength = (totalBits + 7) >> 3; // Calculate byte length needed - byte[] packedBits = new byte[byteLength]; - for (int i = 0; i < bitsPerCoordinate; i++) { - for (int j = 0; j < vector.length; j++) { - if (vector[j] > thresholds[i][j]) { - int bitPosition = i * vector.length + j; - int byteIndex = bitPosition >> 3; // Equivalent to bitPosition / 8 - int bitIndex = 7 - (bitPosition & 7); // Equivalent to 7 - (bitPosition % 8) - packedBits[byteIndex] |= (1 << bitIndex); // Set the bit - } - } - } - + byte[] packedBits = BitPacker.quantizeAndPackBits(vector, thresholds, bitsPerCoordinate); output.updateQuantizedVector(packedBits); } diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java index eab3d992e2..09d34786d6 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java @@ -16,7 +16,6 @@ import org.opensearch.knn.quantization.sampler.SamplingFactory; import java.io.IOException; -import java.util.BitSet; /** * OneBitScalarQuantizer is responsible for quantizing vectors using a single bit per dimension. 
@@ -60,11 +59,12 @@ public OneBitScalarQuantizer(final int samplingSize, final Sampler sampler) { */ @Override public QuantizationState train(final TrainingRequest trainingRequest) { - BitSet sampledDocIds = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); + int[] sampledDocIds = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(trainingRequest, sampledDocIds); return new OneBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), meanThresholds); } + /** * Quantizes the provided vector using the given quantization state. * It compares each dimension of the vector against the corresponding mean (threshold) to determine the quantized value. @@ -85,18 +85,8 @@ public void quantize(final float[] vector, final QuantizationState state, final if (thresholds == null || thresholds.length != vector.length) { throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector."); } - // Directly pack bits without intermediate array - int byteLength = (vector.length + 7) >> 3; // Calculate byte length needed - byte[] packedBits = new byte[byteLength]; - - for (int i = 0; i < vector.length; i++) { - if (vector[i] > thresholds[i]) { - int byteIndex = i >> 3; // Equivalent to i / 8 - int bitIndex = 7 - (i & 7); // Equivalent to 7 - (i % 8) - packedBits[byteIndex] |= (1 << bitIndex); // Set the bit - } - } - + // Use BitPacker to quantize and pack bits for one-bit quantization + byte[] packedBits = BitPacker.quantizeAndPackBits(vector, thresholds); output.updateQuantizedVector(packedBits); } diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java index e956326063..3a8ea0c8db 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java +++ 
b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java @@ -20,21 +20,16 @@ class QuantizerHelper { /** * Calculates the mean vector from a set of sampled vectors. * - *

This method takes a {@link TrainingRequest} object and an array of sampled indices, - * retrieves the vectors corresponding to these indices, and calculates the mean vector. - * Each element of the mean vector is computed as the average of the corresponding elements - * of the sampled vectors.

- * * @param samplingRequest The {@link TrainingRequest} containing the dataset and methods to access vectors by their indices. * @param sampledIndices An array of indices representing the sampled vectors to be used for mean calculation. * @return A float array representing the mean vector of the sampled vectors. * @throws IllegalArgumentException If any of the vectors at the sampled indices are null. * @throws IllegalStateException If the mean array is unexpectedly null after processing the vectors. */ - static float[] calculateMeanThresholds(TrainingRequest samplingRequest, BitSet sampledIndices) { - int totalSamples = sampledIndices.cardinality(); + static float[] calculateMeanThresholds(TrainingRequest samplingRequest, int[] sampledIndices) { + int totalSamples = sampledIndices.length; float[] mean = null; - for (int docId = sampledIndices.nextSetBit(0); docId >= 0; docId = sampledIndices.nextSetBit(docId + 1)) { + for (int docId : sampledIndices) { float[] vector = samplingRequest.getVectorByDocId(docId); if (vector == null) { throw new IllegalArgumentException("Vector at sampled index " + docId + " is null."); @@ -56,24 +51,22 @@ static float[] calculateMeanThresholds(TrainingRequest samplingRequest, } /** - * Calculates the sum, sum of squares, mean, and standard deviation for each dimension in a single pass. + * Calculates the mean and StdDev per dimension for sampled vectors. * * @param trainingRequest the request containing the data and parameters for training. * @param sampledIndices the indices of the sampled vectors. * @param meanArray the array to store the sum and then the mean of each dimension. * @param stdDevArray the array to store the sum of squares and then the standard deviation of each dimension. 
*/ - static void calculateSumMeanAndStdDev( - TrainingRequest trainingRequest, - BitSet sampledIndices, - float[] meanArray, - float[] stdDevArray + static void calculateMeanAndStdDev( + TrainingRequest trainingRequest, + int[] sampledIndices, + float[] meanArray, + float[] stdDevArray ) { - int totalSamples = sampledIndices.cardinality(); + int totalSamples = sampledIndices.length; int dimension = meanArray.length; - - // Single pass to calculate sum and sum of squares - for (int docId = sampledIndices.nextSetBit(0); docId >= 0; docId = sampledIndices.nextSetBit(docId + 1)) { + for (int docId : sampledIndices) { float[] vector = trainingRequest.getVectorByDocId(docId); if (vector == null) { throw new IllegalArgumentException("Vector at sampled index " + docId + " is null."); diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java index 22a7bae230..4322cedd94 100644 --- a/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java +++ b/src/main/java/org/opensearch/knn/quantization/sampler/ReservoirSampler.java @@ -7,8 +7,10 @@ import lombok.NoArgsConstructor; +import java.util.Arrays; import java.util.BitSet; import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.IntStream; /** * ReservoirSampler implements the Sampler interface and provides a method for sampling @@ -45,11 +47,9 @@ public static synchronized ReservoirSampler getInstance() { * @return an array of sampled indices. 
*/ @Override - public BitSet sample(final int totalNumberOfVectors, final int sampleSize) { + public int[] sample(final int totalNumberOfVectors, final int sampleSize) { if (totalNumberOfVectors <= sampleSize) { - BitSet bitSet = new BitSet(totalNumberOfVectors); - bitSet.set(0, totalNumberOfVectors); - return bitSet; + return IntStream.range(0, totalNumberOfVectors).toArray(); } return reservoirSampleIndices(totalNumberOfVectors, sampleSize); } @@ -65,24 +65,27 @@ public BitSet sample(final int totalNumberOfVectors, final int sampleSize) { * * @param numVectors the total number of vectors. * @param sampleSize the number of indices to sample. - * @return a BitSet representing the sampled indices. + * @return an array of sampled indices. */ - private BitSet reservoirSampleIndices(final int numVectors, final int sampleSize) { + private int[] reservoirSampleIndices(final int numVectors, final int sampleSize) { int[] indices = new int[sampleSize]; + + // Initialize the reservoir with the first sampleSize elements for (int i = 0; i < sampleSize; i++) { indices[i] = i; } + + // Replace elements with gradually decreasing probability for (int i = sampleSize; i < numVectors; i++) { int j = ThreadLocalRandom.current().nextInt(i + 1); if (j < sampleSize) { indices[j] = i; } } - // Using BitSet to track the presence of indices - BitSet bitSet = new BitSet(numVectors); - for (int i = 0; i < sampleSize; i++) { - bitSet.set(indices[i]); - } - return bitSet; + + // Sort the sampled indices + Arrays.sort(indices); + + return indices; } } diff --git a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java index 17834cf043..828d878018 100644 --- a/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java +++ b/src/main/java/org/opensearch/knn/quantization/sampler/Sampler.java @@ -23,5 +23,5 @@ public interface Sampler { * @return an array of integers representing the indices of the sampled 
vectors. * @throws IllegalArgumentException if the sample size is greater than the total number of vectors. */ - BitSet sample(int totalNumberOfVectors, int sampleSize); + int[] sample(int totalNumberOfVectors, int sampleSize); } diff --git a/src/main/java/org/opensearch/knn/quantization/util/VersionContext.java b/src/main/java/org/opensearch/knn/quantization/util/VersionContext.java deleted file mode 100644 index 7746305abe..0000000000 --- a/src/main/java/org/opensearch/knn/quantization/util/VersionContext.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.quantization.util; - -import lombok.experimental.UtilityClass; - -/** - * Utility class to manage version information in a thread-safe manner using ThreadLocal storage. - * This class ensures that version information is available within the current thread context. - */ -@UtilityClass -public class VersionContext { - - /** - * ThreadLocal storage for version information. - * This allows each thread to have its own version information without interference. - */ - private final ThreadLocal versionHolder = new ThreadLocal<>(); - - /** - * Sets the version for the current thread. - * - * @param version the version to be set. - */ - public void setVersion(int version) { - versionHolder.set(version); - } - - /** - * Gets the version for the current thread. - * - * @return the version for the current thread. - */ - public int getVersion() { - return versionHolder.get(); - } - - /** - * Clears the version for the current thread. 
- */ - public void clear() { - versionHolder.remove(); - } -} diff --git a/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java index 815f810710..99621a0e53 100644 --- a/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java +++ b/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java @@ -7,6 +7,9 @@ import org.opensearch.knn.KNNTestCase; +import java.util.HashSet; +import java.util.Set; + public class ScalarQuantizationTypeTests extends KNNTestCase { public void testSQTypesValues() { ScalarQuantizationType[] expectedValues = { @@ -21,4 +24,12 @@ public void testSQTypesValueOf() { assertEquals(ScalarQuantizationType.TWO_BIT, ScalarQuantizationType.valueOf("TWO_BIT")); assertEquals(ScalarQuantizationType.FOUR_BIT, ScalarQuantizationType.valueOf("FOUR_BIT")); } + + public void testUniqueSQTypeValues() { + Set uniqueIds = new HashSet<>(); + for (ScalarQuantizationType type : ScalarQuantizationType.values()) { + boolean added = uniqueIds.add(type.getId()); + assertTrue("Duplicate value found: " + type.getId(), added); + } + } } diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java index 974f586371..c6bcfb3a2d 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateSerializerTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.quantization.quantizationState; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import 
org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; @@ -15,25 +16,33 @@ public class QuantizationStateSerializerTests extends KNNTestCase { - public void testSerializeAndDeserializeOneBitScalarQuantizationState() throws IOException, ClassNotFoundException { + public void testSerializeAndDeserializeOneBitScalarQuantizationState() throws IOException { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); float[] mean = new float[] { 0.1f, 0.2f, 0.3f }; OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + // Serialize byte[] serialized = state.toByteArray(); - OneBitScalarQuantizationState deserialized = OneBitScalarQuantizationState.fromByteArray(serialized); + + // Deserialize + StreamInput in = StreamInput.wrap(serialized); + OneBitScalarQuantizationState deserialized = new OneBitScalarQuantizationState(in); assertArrayEquals(mean, deserialized.getMeanThresholds(), 0.0f); assertEquals(params, deserialized.getQuantizationParams()); } - public void testSerializeAndDeserializeMultiBitScalarQuantizationState() throws IOException, ClassNotFoundException { + public void testSerializeAndDeserializeMultiBitScalarQuantizationState() throws IOException { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); float[][] thresholds = new float[][] { { 0.1f, 0.2f, 0.3f }, { 0.4f, 0.5f, 0.6f } }; MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + // Serialize byte[] serialized = state.toByteArray(); - MultiBitScalarQuantizationState deserialized = MultiBitScalarQuantizationState.fromByteArray(serialized); + + // Deserialize + StreamInput in = StreamInput.wrap(serialized); + MultiBitScalarQuantizationState deserialized = new MultiBitScalarQuantizationState(in); for (int i = 0; i < thresholds.length; i++) { assertArrayEquals(thresholds[i], deserialized.getThresholds()[i], 0.0f); diff 
--git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java index 5a6b4b1dba..834440bcae 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java @@ -6,52 +6,44 @@ package org.opensearch.knn.quantization.quantizationState; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; -import org.opensearch.Version; -import org.opensearch.knn.quantization.util.VersionContext; - import java.io.IOException; public class QuantizationStateTests extends KNNTestCase { - public void testOneBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException { + public void testOneBitScalarQuantizationStateSerialization() throws IOException { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); float[] mean = { 1.0f, 2.0f, 3.0f }; OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); - // Set the version for serialization - VersionContext.setVersion(Version.CURRENT.id); - // Serialize byte[] serializedState = state.toByteArray(); // Deserialize - OneBitScalarQuantizationState deserializedState = OneBitScalarQuantizationState.fromByteArray(serializedState); + StreamInput in = StreamInput.wrap(serializedState); + OneBitScalarQuantizationState deserializedState = new OneBitScalarQuantizationState(in); float delta = 0.0001f; 
assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta); assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType()); } - public void testMultiBitScalarQuantizationStateSerialization() throws IOException, ClassNotFoundException { + public void testMultiBitScalarQuantizationStateSerialization() throws IOException { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); float[][] thresholds = { { 0.5f, 1.5f, 2.5f }, { 1.0f, 2.0f, 3.0f } }; MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); - - // Set the version for serialization - VersionContext.setVersion(Version.CURRENT.id); - - // Serialize byte[] serializedState = state.toByteArray(); // Deserialize - MultiBitScalarQuantizationState deserializedState = MultiBitScalarQuantizationState.fromByteArray(serializedState); + StreamInput in = StreamInput.wrap(serializedState); + MultiBitScalarQuantizationState deserializedState = new MultiBitScalarQuantizationState(in); float delta = 0.0001f; for (int i = 0; i < thresholds.length; i++) { @@ -60,21 +52,14 @@ public void testMultiBitScalarQuantizationStateSerialization() throws IOExceptio assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType()); } - public void testSerializationWithDifferentVersions() throws IOException, ClassNotFoundException { + public void testSerializationWithDifferentVersions() throws IOException { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); float[] mean = { 1.0f, 2.0f, 3.0f }; OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); - - // Simulate an older version - VersionContext.setVersion(Version.V_2_0_0.id); - - // Serialize byte[] serializedState = state.toByteArray(); - - // Update to a new version and deserialize - VersionContext.setVersion(Version.CURRENT.id); - OneBitScalarQuantizationState 
deserializedState = OneBitScalarQuantizationState.fromByteArray(serializedState); + StreamInput in = StreamInput.wrap(serializedState); + OneBitScalarQuantizationState deserializedState = new OneBitScalarQuantizationState(in); float delta = 0.0001f; assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta); diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java index 8372ac5d23..6c81ae69c5 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.quantization.quantizer; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; @@ -17,8 +18,6 @@ import org.opensearch.knn.quantization.sampler.SamplingFactory; import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; import java.util.BitSet; public class OneBitScalarQuantizerTests extends KNNTestCase { @@ -83,13 +82,8 @@ public byte[] toByteArray() { } @Override - public void writeExternal(ObjectOutput out) throws IOException { - // no-op - } - - @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - // no-op + public void writeTo(StreamOutput out) throws IOException { + // Empty implementation for test } }; BinaryQuantizationOutput output = new BinaryQuantizationOutput(); @@ -120,7 +114,7 @@ public float[] getVectorByDocId(int docId) { }; Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR); - BitSet sampledIndices = sampler.sample(vectors.length, 3); + int[] sampledIndices = sampler.sample(vectors.length, 3); 
float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(samplingRequest, sampledIndices); assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, meanThresholds, 0.001f); } @@ -137,7 +131,7 @@ public float[] getVectorByDocId(int docId) { }; Sampler sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR); - BitSet sampledIndices = sampler.sample(vectors.length, 3); + int[] sampledIndices = sampler.sample(vectors.length, 3); expectThrows(IllegalArgumentException.class, () -> QuantizerHelper.calculateMeanThresholds(samplingRequest, sampledIndices)); } } diff --git a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java index e930aef046..88d6689805 100644 --- a/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/sampler/ReservoirSamplerTests.java @@ -7,7 +7,8 @@ import org.opensearch.knn.KNNTestCase; -import java.util.BitSet; +import java.util.Arrays; +import java.util.stream.IntStream; public class ReservoirSamplerTests extends KNNTestCase { @@ -15,20 +16,18 @@ public void testSampleLessThanSampleSize() { ReservoirSampler sampler = ReservoirSampler.getInstance(); int totalNumberOfVectors = 5; int sampleSize = 10; - BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - BitSet expectedIndices = new BitSet(totalNumberOfVectors); - expectedIndices.set(0, totalNumberOfVectors); - assertEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); + int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray(); + assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); } public void testSampleEqualToSampleSize() { ReservoirSampler sampler = ReservoirSampler.getInstance(); int 
totalNumberOfVectors = 10; int sampleSize = 10; - BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - BitSet expectedIndices = new BitSet(totalNumberOfVectors); - expectedIndices.set(0, totalNumberOfVectors); - assertEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); + int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + int[] expectedIndices = IntStream.range(0, totalNumberOfVectors).toArray(); + assertArrayEquals("Sampled indices should include all available indices.", expectedIndices, sampledIndices); } public void testSampleRandomness() { @@ -37,25 +36,28 @@ public void testSampleRandomness() { int totalNumberOfVectors = 100; int sampleSize = 10; - BitSet sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize); - BitSet sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize); + int[] sampledIndices1 = sampler1.sample(totalNumberOfVectors, sampleSize); + int[] sampledIndices2 = sampler2.sample(totalNumberOfVectors, sampleSize); - assertNotEquals(sampledIndices1, sampledIndices2); + // It's unlikely but possible for the two samples to be equal; sort both so the comparison is order-independent, then check that they differ + Arrays.sort(sampledIndices1); + Arrays.sort(sampledIndices2); + assertFalse("Sampled indices should be different", Arrays.equals(sampledIndices1, sampledIndices2)); } public void testEdgeCaseZeroVectors() { ReservoirSampler sampler = ReservoirSampler.getInstance(); int totalNumberOfVectors = 0; int sampleSize = 10; - BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - assertEquals(0, sampledIndices.cardinality()); + int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals("Sampled indices should be empty when there are zero vectors.", 0, sampledIndices.length); } public void testEdgeCaseZeroSampleSize() { ReservoirSampler sampler = ReservoirSampler.getInstance(); int totalNumberOfVectors = 10; int sampleSize = 
0; - BitSet sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); - assertEquals(0, sampledIndices.cardinality()); + int[] sampledIndices = sampler.sample(totalNumberOfVectors, sampleSize); + assertEquals("Sampled indices should be empty when sample size is zero.", 0, sampledIndices.length); } -} +} \ No newline at end of file