Skip to content

Commit

Permalink
Implemented Serlization using Writable
Browse files Browse the repository at this point in the history
Signed-off-by: VIKASH TIWARI <[email protected]>
  • Loading branch information
Vikasht34 committed Aug 8, 2024
1 parent 0ebbef0 commit 6e34b68
Show file tree
Hide file tree
Showing 11 changed files with 123 additions and 296 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.knn.quantization.models.quantizationParams;

import java.io.Externalizable;
import org.opensearch.core.common.io.stream.Writeable;

/**
* Interface for quantization parameters.
Expand All @@ -14,7 +14,7 @@
* Implementations of this interface are expected to provide specific configurations
* for various quantization strategies.
*/
public interface QuantizationParams extends Externalizable {
public interface QuantizationParams extends Writeable {
/**
* Provides a unique identifier for the quantization parameters.
* This identifier is typically a combination of the quantization type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Locale;

/**
* The SQParams class represents the parameters specific to scalar quantization (SQ).
* The ScalarQuantizationParams class represents the parameters specific to scalar quantization (SQ).
* This class implements the QuantizationParams interface and includes the type of scalar quantization.
*/
@Getter
Expand All @@ -39,67 +39,40 @@ public static String generateTypeIdentifier(ScalarQuantizationType sqType) {
}

/**
* Serializes the SQParams object to an external output.
* This method writes the scalar quantization type to the output stream.
* Provides a unique type identifier for the ScalarQuantizationParams, combining the SQ type.
* This identifier is useful for distinguishing between different configurations of scalar quantization parameters.
*
* @param out the ObjectOutput to write the object to.
* @throws IOException if an I/O error occurs during serialization.
* @return A string representing the unique type identifier.
*/
@Override
public void writeExternal(ObjectOutput out) throws IOException {
// The version is already written by the parent state class, no need to write it here again
// Retrieve the current version from VersionContext
// This context will be used by other classes involved in the serialization process.
// Example:
// int version = VersionContext.getVersion(); // Get the current version from VersionContext
// Any Version Specific logic can be wriiten based on Version
out.writeObject(sqType);
public String getTypeIdentifier() {
return generateIdentifier(sqType.getId());
}

private static String generateIdentifier(int id) {
return String.format(Locale.ROOT, "ScalarQuantizationParams_%d", id);
}

/**
* Deserializes the SQParams object from an external input with versioning.
* This method reads the scalar quantization type and new field from the input stream based on the version.
* Writes the object to the output stream.
* This method is part of the Writeable interface and is used to serialize the object.
*
* @param in the ObjectInput to read the object from.
* @throws IOException if an I/O error occurs during deserialization.
* @throws ClassNotFoundException if the class of the serialized object cannot be found.
* @param out the output stream to write the object to.
* @throws IOException if an I/O error occurs.
*/
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
// The version is already read by the parent state class and set in VersionContext
// Retrieve the current version from VersionContext to handle version-specific deserialization logic
// int versionId = VersionContext.getVersion();
// Version version = Version.fromId(versionId);

sqType = (ScalarQuantizationType) in.readObject();

// Add version-specific deserialization logic
// For example, if new fields are added in a future version, handle them here
// This section contains conditional logic to handle different versions appropriately.
// Example:
// if (version.onOrAfter(Version.V_1_0_0) && version.before(Version.V_2_0_0)) {
// // Handle logic for versions between 1.0.0 and 2.0.0
// // Example: Read additional fields introduced in version 1.0.0
// // newField = in.readInt();
// } else if (version.onOrAfter(Version.V_2_0_0)) {
// // Handle logic for versions 2.0.0 and above
// // Example: Read additional fields introduced in version 2.0.0
// // anotherNewField = in.readFloat();
// }
public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(sqType);
}

/**
* Provides a unique type identifier for the SQParams, combining the SQ type.
* This identifier is useful for distinguishing between different configurations of scalar quantization parameters.
* Reads the object from the input stream.
* This method is part of the Writeable interface and is used to deserialize the object.
*
* @return A string representing the unique type identifier.
* @param in the input stream to read the object from.
* @throws IOException if an I/O error occurs.
*/
@Override
public String getTypeIdentifier() {
return generateIdentifier(sqType.getId());
}

private static String generateIdentifier(int id) {
return String.format(Locale.ROOT, "ScalarQuantizationParams_%d", id);
public ScalarQuantizationParams(StreamInput in, int version) throws IOException {
this.sqType = in.readEnum(ScalarQuantizationType.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;

/**
* DefaultQuantizationState is used as a fallback state when no training is required or if training fails.
Expand All @@ -27,16 +27,22 @@ public class DefaultQuantizationState implements QuantizationState {
private QuantizationParams params;
private static final long serialVersionUID = 1L; // Version ID for serialization

/**
* Returns the quantization parameters associated with this state.
*
* @return the quantization parameters.
*/
@Override
public QuantizationParams getQuantizationParams() {
return params;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(Version.CURRENT.id); // Write the version
params.writeTo(out);
}

public DefaultQuantizationState(StreamInput in) throws IOException {
int version = in.readInt(); // Read the version
this.params = new ScalarQuantizationParams(in, version);
}

/**
* Serializes the quantization state to a byte array.
*
Expand All @@ -45,7 +51,7 @@ public QuantizationParams getQuantizationParams() {
*/
@Override
public byte[] toByteArray() throws IOException {
return QuantizationStateSerializer.serialize(this, null);
return QuantizationStateSerializer.serialize(this);
}

/**
Expand All @@ -57,36 +63,6 @@ public byte[] toByteArray() throws IOException {
* @throws ClassNotFoundException if the class of the serialized object cannot be found.
*/
public static DefaultQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException {
return (DefaultQuantizationState) QuantizationStateSerializer.deserialize(
bytes,
new DefaultQuantizationState(),
(parentParams, specificData) -> new DefaultQuantizationState((ScalarQuantizationParams) parentParams)
);
}

/**
* Writes the object to the output stream.
* This method is part of the Externalizable interface and is used to serialize the object.
*
* @param out the output stream to write the object to.
* @throws IOException if an I/O error occurs.
*/
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeInt(Version.CURRENT.id); // Write the version
out.writeObject(params);
}

/**
* Reads the object from the input stream.
* This method is part of the Externalizable interface and is used to deserialize the object.
*
* @param in the input stream to read the object from.
* @throws IOException if an I/O error occurs.
* @throws ClassNotFoundException if the class of the serialized object cannot be found.
*/
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
this.params = (QuantizationParams) in.readObject();
return (DefaultQuantizationState) QuantizationStateSerializer.deserialize(bytes, DefaultQuantizationState::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.util.VersionContext;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;

/**
* MultiBitScalarQuantizationState represents the state of multi-bit scalar quantization,
Expand Down Expand Up @@ -50,37 +49,31 @@ public ScalarQuantizationParams getQuantizationParams() {
}

/**
* This method is responsible for writing the state of the OneBitScalarQuantizationState object to an external output.
* This method is responsible for writing the state of the MultiBitScalarQuantizationState object to an external output.
* It includes versioning information to ensure compatibility between different versions of the serialized object.
*
* <p>Versioning is managed using the {@link VersionContext} class. This allows other classes that are serialized
* as part of the state to access the version information and implement version-specific logic if needed.</p>
*
* <p>The {@link VersionContext#setVersion(int)} method sets the version information in a thread-local variable,
* ensuring that the version is available to all classes involved in the serialization process within the current thread context.</p>
*
* <pre>
* {@code
* // Example usage in the writeExternal method:
* VersionContext.setVersion(version);
* out.writeInt(version); // Write the version
* quantizationParams.writeExternal(out);
* out.writeInt(meanThresholds.length);
* for (float mean : meanThresholds) {
* out.writeFloat(mean);
* quantizationParams.writeTo(out);
* out.writeInt(thresholds.length);
* out.writeInt(thresholds[0].length);
* for (float[] row : thresholds) {
* for (float value : row) {
* out.writeFloat(value);
* }
* }
* }
* </pre>
*
* @param out the ObjectOutput to write the object to.
* @param out the StreamOutput to write the object to.
* @throws IOException if an I/O error occurs during serialization.
*/
@Override
public void writeExternal(ObjectOutput out) throws IOException {
int version = Version.CURRENT.id;
VersionContext.setVersion(version);
out.writeInt(version); // Write the version
quantizationParams.writeExternal(out);
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(Version.CURRENT.id); // Write the version
quantizationParams.writeTo(out);
out.writeInt(thresholds.length);
out.writeInt(thresholds[0].length);
for (float[] row : thresholds) {
Expand All @@ -91,40 +84,32 @@ public void writeExternal(ObjectOutput out) throws IOException {
}

/**
* This method is responsible for reading the state of the OneBitScalarQuantizationState object from an external input.
* This method is responsible for reading the state of the MultiBitScalarQuantizationState object from an external input.
* It includes versioning information to ensure compatibility between different versions of the serialized object.
*
* <p>The version information is read first, and then it is set using the {@link VersionContext#setVersion(int)} method.
* This makes the version information available to all classes involved in the deserialization process within the current thread context.</p>
*
* <p>Classes that are part of the deserialization process can retrieve the version information using the
* {@link VersionContext#getVersion()} method and implement version-specific logic accordingly.</p>
*
* <pre>
* {@code
* // Example usage in the readExternal method:
* int version = in.readInt(); // Read the version
* VersionContext.setVersion(version);
* quantizationParams = new ScalarQuantizationParams();
* quantizationParams.readExternal(in); // Use readExternal of SQParams
* int length = in.readInt();
* meanThresholds = new float[length];
* for (int i = 0; i < length; i++) {
* meanThresholds[i] = in.readFloat();
* quantizationParams = new ScalarQuantizationParams(in);
* int rows = in.readInt();
* int cols = in.readInt();
* thresholds = new float[rows][cols];
* for (int i = 0; i < rows; i++) {
* for (int j = 0; j < cols; j++) {
* thresholds[i][j] = in.readFloat();
* }
* }
* }
* </pre>
*
* @param in the ObjectInput to read the object from.
* @param in the StreamInput to read the object from.
* @throws IOException if an I/O error occurs during deserialization.
* @throws ClassNotFoundException if the class of the serialized object cannot be found.
*/
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
public MultiBitScalarQuantizationState(StreamInput in) throws IOException {
int version = in.readInt(); // Read the version
VersionContext.setVersion(version);
quantizationParams = new ScalarQuantizationParams();
quantizationParams.readExternal(in); // Use readExternal of SQParams
this.quantizationParams = new ScalarQuantizationParams(in, version);
int rows = in.readInt();
int cols = in.readInt();
thresholds = new float[rows][cols];
Expand All @@ -133,7 +118,6 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept
thresholds[i][j] = in.readFloat();
}
}
VersionContext.clear(); // Clear the version after use
}

/**
Expand All @@ -155,7 +139,7 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept
*/
@Override
public byte[] toByteArray() throws IOException {
return QuantizationStateSerializer.serialize(this, thresholds);
return QuantizationStateSerializer.serialize(this);
}

/**
Expand All @@ -175,16 +159,8 @@ public byte[] toByteArray() throws IOException {
* @param bytes the byte array containing the serialized state.
* @return the deserialized MultiBitScalarQuantizationState object.
* @throws IOException if an I/O error occurs during deserialization.
* @throws ClassNotFoundException if the class of a serialized object cannot be found.
*/
public static MultiBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException {
return (MultiBitScalarQuantizationState) QuantizationStateSerializer.deserialize(
bytes,
new MultiBitScalarQuantizationState(),
(parentParams, thresholds) -> new MultiBitScalarQuantizationState(
(ScalarQuantizationParams) parentParams,
(float[][]) thresholds
)
);
public static MultiBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException {
return (MultiBitScalarQuantizationState) QuantizationStateSerializer.deserialize(bytes, MultiBitScalarQuantizationState::new);
}
}
Loading

0 comments on commit 6e34b68

Please sign in to comment.