Skip to content

Commit

Permalink
Quantization Framework Implementation with 1bit and MultiBit Binary Q…
Browse files Browse the repository at this point in the history
…uantizer

Signed-off-by: VIKASH TIWARI <[email protected]>
  • Loading branch information
Vikasht34 committed Aug 7, 2024
1 parent 4fda2c5 commit 20f785b
Show file tree
Hide file tree
Showing 39 changed files with 536 additions and 690 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920)
* Refactor KNNVectorFieldType from KNNVectorFieldMapper to a separate class for better readability. [#1931](https://github.com/opensearch-project/k-NN/pull/1931)
* Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925)
* Quantization Framework For Disk Optimized Vector Search and Implementation of Binary 1Bit and multibit quantizer[#1889](https://github.com/opensearch-project/k-NN/issues/1889)
* Added Quantization Framework and implemented 1Bit and multibit quantizer [#1889](https://github.com/opensearch-project/k-NN/issues/1889)

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,39 @@
package org.opensearch.knn.quantization.enums;

/**
* The SQTypes enum defines the various scalar quantization types that can be used
* in the KNN for vector quantization.
* Each type corresponds to a different bit-width representation of the quantized values.
* The ScalarQuantizationType enum defines the various scalar quantization types that can be used
* for vector quantization.
* Each type corresponds to a different bit and byte representation of the quantized values.
*/
public enum ScalarQuantizationType {
/**
* ONE_BIT quantization uses a single bit per coordinate.
* In the future, if you change the name, please don't change the value, as
* serialization and deserialization depend on it.
*/
ONE_BIT,
ONE_BIT(1),

/**
* TWO_BIT quantization uses two bits per coordinate.
* In the future, if you change the name, please don't change the value, as
* serialization and deserialization depend on it.
*/
TWO_BIT,
TWO_BIT(2),

/**
* FOUR_BIT quantization uses four bits per coordinate.
* In the future, if you change the name, please don't change the value, as
* serialization and deserialization depend on it.
*/
FOUR_BIT,
FOUR_BIT(4);

/**
* UNSUPPORTED_TYPE is used to denote quantization types that are not supported.
* This can be used as a placeholder or default value.
*/
UNSUPPORTED_TYPE
private final int id;

ScalarQuantizationType(int id) {
this.id = id;
}

public int getId() {
return id;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.knn.quantization.factory;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.quantizer.Quantizer;

Expand All @@ -15,12 +17,10 @@
* based on the provided {@link QuantizationParams}. It uses a registry to look up the
* appropriate quantizer implementation for the given quantization parameters.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public final class QuantizerFactory {
private static final AtomicBoolean isRegistered = new AtomicBoolean(false);

// Private constructor to prevent instantiation
private QuantizerFactory() {}

/**
* Ensures that default quantizers are registered.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

package org.opensearch.knn.quantization.factory;

import org.opensearch.knn.quantization.enums.QuantizationType;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationParams.SQParams;
import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer;
Expand All @@ -15,34 +16,22 @@
* The QuantizerRegistrar class is responsible for registering default quantizers.
* This class ensures that the registration happens only once in a thread-safe manner.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
final class QuantizerRegistrar {

// Private constructor to prevent instantiation
private QuantizerRegistrar() {}

/**
* Registers default quantizers if not already registered.
* <p>
* This method is synchronized to ensure that registration occurs only once,
* even in a multi-threaded environment.
* </p>
*/
public static synchronized void registerDefaultQuantizers() {
static synchronized void registerDefaultQuantizers() {
// Register OneBitScalarQuantizer for SQParams with VALUE_QUANTIZATION and ScalarQuantizationType.ONE_BIT
QuantizerRegistry.register(SQParams.class, QuantizationType.VALUE, ScalarQuantizationType.ONE_BIT, OneBitScalarQuantizer::new);
QuantizerRegistry.register(new SQParams(ScalarQuantizationType.ONE_BIT).getTypeIdentifier(), OneBitScalarQuantizer::new);
// Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION and 2 bits per coordinate
QuantizerRegistry.register(
SQParams.class,
QuantizationType.VALUE,
ScalarQuantizationType.TWO_BIT,
() -> new MultiBitScalarQuantizer(2)
);
QuantizerRegistry.register(new SQParams(ScalarQuantizationType.TWO_BIT).getTypeIdentifier(), () -> new MultiBitScalarQuantizer(2));
// Register MultiBitScalarQuantizer for SQParams with VALUE_QUANTIZATION and 4 bits per coordinate
QuantizerRegistry.register(
SQParams.class,
QuantizationType.VALUE,
ScalarQuantizationType.FOUR_BIT,
() -> new MultiBitScalarQuantizer(4)
);
QuantizerRegistry.register(new SQParams(ScalarQuantizationType.FOUR_BIT).getTypeIdentifier(), () -> new MultiBitScalarQuantizer(4));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

package org.opensearch.knn.quantization.factory;

import org.opensearch.knn.quantization.enums.QuantizationType;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.quantizer.Quantizer;

Expand All @@ -19,32 +19,20 @@
* of quantizer instances. Quantizers are registered with specific quantization parameters
* and type identifiers, allowing for efficient lookup and instantiation.
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
final class QuantizerRegistry {

// Private constructor to prevent instantiation
private QuantizerRegistry() {}

// ConcurrentHashMap for thread-safe access
private static final Map<String, Supplier<? extends Quantizer<?, ?>>> registry = new ConcurrentHashMap<>();

/**
* Registers a quantizer with the registry.
*
* @param paramClass the class of the quantization parameters
* @param quantizationType the quantization type (e.g., VALUE_QUANTIZATION)
* @param sqType the specific quantization subtype (e.g., ONE_BIT, TWO_BIT)
* @param paramIdentifier the unique identifier for the quantization parameters
* @param quantizerSupplier a supplier that provides instances of the quantizer
* @param <P> the type of quantization parameters
*/
public static <P extends QuantizationParams> void register(
final Class<P> paramClass,
final QuantizationType quantizationType,
final ScalarQuantizationType sqType,
final Supplier<? extends Quantizer<?, ?>> quantizerSupplier
) {
String identifier = createIdentifier(quantizationType, sqType);
public static void register(final String paramIdentifier, final Supplier<? extends Quantizer<?, ?>> quantizerSupplier) {
// Ensure that the quantizer for this identifier is registered only once
registry.computeIfAbsent(identifier, key -> quantizerSupplier);
registry.computeIfAbsent(paramIdentifier, key -> quantizerSupplier);
}

/**
Expand All @@ -60,23 +48,10 @@ public static <P extends QuantizationParams, Q> Quantizer<P, Q> getQuantizer(fin
String identifier = params.getTypeIdentifier();
Supplier<? extends Quantizer<?, ?>> supplier = registry.get(identifier);
if (supplier == null) {
throw new IllegalArgumentException(
"No quantizer registered for type identifier: " + identifier + ". Available quantizers: " + registry.keySet()
);
throw new IllegalArgumentException("No quantizer registered for type identifier: " + identifier);
}
@SuppressWarnings("unchecked")
Quantizer<P, Q> quantizer = (Quantizer<P, Q>) supplier.get();
return quantizer;
}

/**
* Creates a unique identifier for the quantizer based on the quantization type and specific quantization subtype.
*
* @param quantizationType the quantization type
* @param sqType the specific quantization subtype
* @return a string identifier
*/
private static String createIdentifier(final QuantizationType quantizationType, final ScalarQuantizationType sqType) {
return quantizationType.name() + "_" + sqType.name();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,38 @@

package org.opensearch.knn.quantization.models.quantizationOutput;

import java.io.ByteArrayOutputStream;
import java.io.IOException;

/**
* The BinaryQuantizationOutput class represents the output of a quantization process in binary format.
* It implements the QuantizationOutput interface to handle byte arrays specifically.
*/
public class BinaryQuantizationOutput implements QuantizationOutput<byte[]> {
private final byte[] quantizedVector;
private final ByteArrayOutputStream byteArrayOutputStream;

/**
* Constructs a BinaryQuantizationOutput instance with a default initial buffer size.
*/
public BinaryQuantizationOutput() {
this.byteArrayOutputStream = new ByteArrayOutputStream();
}

/**
* Constructs a BinaryQuantizationOutput instance with the specified quantized vector.
* Updates the quantized vector with a new byte array.
*
* @param quantizedVector the quantized vector represented as a byte array.
* @param newQuantizedVector the new quantized vector represented as a byte array.
*/
public BinaryQuantizationOutput(final byte[] quantizedVector) {
if (quantizedVector == null) {
throw new IllegalArgumentException("Quantized vector cannot be null");
public void updateQuantizedVector(final byte[] newQuantizedVector) throws IOException {
if (newQuantizedVector == null || newQuantizedVector.length == 0) {
throw new IllegalArgumentException("Quantized vector cannot be null or empty");
}
this.quantizedVector = quantizedVector;
byteArrayOutputStream.reset();
byteArrayOutputStream.write(newQuantizedVector);
}

@Override
public byte[] getQuantizedVector() {
return quantizedVector;
return byteArrayOutputStream.toByteArray();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.knn.quantization.models.quantizationOutput;

import java.io.IOException;

/**
* The QuantizationOutput interface defines the contract for quantization output data.
*
Expand All @@ -17,4 +19,12 @@ public interface QuantizationOutput<T> {
* @return the quantized data.
*/
T getQuantizedVector();

/**
* Updates the quantized vector with new data.
*
* @param newQuantizedVector the new quantized vector data.
* @throws IOException if an I/O error occurs during the update.
*/
void updateQuantizedVector(T newQuantizedVector) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

package org.opensearch.knn.quantization.models.quantizationParams;

import org.opensearch.knn.quantization.enums.QuantizationType;

import java.io.Serializable;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;

/**
* Interface for quantization parameters.
Expand All @@ -16,17 +17,7 @@
* Implementations of this interface are expected to provide specific configurations
* for various quantization strategies.
*/
public interface QuantizationParams extends Serializable {

/**
* Gets the quantization type associated with the parameters.
* The quantization type defines the overall strategy or method used
* for quantization, such as VALUE_QUANTIZATION or SPACE_QUANTIZATION.
*
* @return the {@link QuantizationType} indicating the quantization method.
*/
QuantizationType getQuantizationType();

public interface QuantizationParams extends Externalizable {
/**
* Provides a unique identifier for the quantization parameters.
* This identifier is typically a combination of the quantization type
Expand All @@ -36,4 +27,29 @@ public interface QuantizationParams extends Serializable {
* @return a string representing the unique type identifier.
*/
String getTypeIdentifier();

/**
* Serializes the QuantizationParams object to an external output.
* The default implementation is a no-op.
*
* @param out the ObjectOutput to write the object to.
* @throws IOException if an I/O error occurs during serialization.
*/
@Override
default void writeExternal(ObjectOutput out) throws IOException {
// Default no-op implementation
}

/**
* Deserializes the QuantizationParams object from an external input.
* The default implementation is a no-op.
*
* @param in the ObjectInput to read the object from.
* @throws IOException if an I/O error occurs during deserialization.
* @throws ClassNotFoundException if the class of the serialized object cannot be found.
*/
@Override
default void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
// Default no-op implementation
}
}
Loading

0 comments on commit 20f785b

Please sign in to comment.