Skip to content

Commit

Permalink
Quantization Framework Code Structure Improvement
Browse files Browse the repository at this point in the history
Signed-off-by: VIKASH TIWARI <[email protected]>
  • Loading branch information
Vikasht34 committed Aug 15, 2024
1 parent 88b1cc4 commit af650d4
Show file tree
Hide file tree
Showing 6 changed files with 277 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,53 +6,52 @@
package org.opensearch.knn.quantization.models.quantizationOutput;

import lombok.Getter;
import lombok.NoArgsConstructor;

import java.util.Arrays;
import lombok.RequiredArgsConstructor;

/**
* The BinaryQuantizationOutput class represents the output of a quantization process in binary format.
* It implements the QuantizationOutput interface to handle byte arrays specifically.
*/
@NoArgsConstructor
@Getter
@RequiredArgsConstructor
public class BinaryQuantizationOutput implements QuantizationOutput<byte[]> {
@Getter
private byte[] quantizedVector;
private final int bitsPerCoordinate;
private int currentVectorLength = -1; // Indicates uninitialized state

/**
* Prepares the quantized vector array based on the provided parameters and returns it for direct modification.
* This method ensures that the internal byte array is appropriately sized and cleared before being used.
* The method accepts two parameters:
* <ul>
* <li><b>bitsPerCoordinate:</b> The number of bits used per coordinate. This determines the granularity of the quantization.</li>
* <li><b>vectorLength:</b> The length of the original vector that needs to be quantized. This helps in calculating the required byte array size.</li>
* </ul>
* If the existing quantized vector is either null or not the same size as the required byte array,
* a new byte array is allocated. Otherwise, the existing array is cleared (i.e., all bytes are set to zero).
* This method is designed to be used in conjunction with a bit-packing utility that writes quantized values directly
* into the returned byte array.
* @param params an array of parameters, where the first parameter is the number of bits per coordinate (int),
* and the second parameter is the length of the vector (int).
* @return the prepared and writable quantized vector as a byte array.
* @throws IllegalArgumentException if the parameters are not as expected (e.g., missing or not integers).
* Prepares the quantized vector based on the vector length.
* This includes initializing or resetting the quantized vector.
*
* @param vectorLength The length of the vector to be quantized.
*/
@Override
public byte[] prepareAndGetWritableQuantizedVector(Object... params) {
if (params.length != 2 || !(params[0] instanceof Integer) || !(params[1] instanceof Integer)) {
throw new IllegalArgumentException("Expected two integer parameters: bitsPerCoordinate and vectorLength");
public void prepareQuantizedVector(int vectorLength) {
if (vectorLength <= 0) {
throw new IllegalArgumentException("Vector length must be greater than zero.");
}
int bitsPerCoordinate = (int) params[0];
int vectorLength = (int) params[1];
int totalBits = bitsPerCoordinate * vectorLength;
int byteLength = (totalBits + 7) >> 3;

if (this.quantizedVector == null || this.quantizedVector.length != byteLength) {
if (vectorLength != currentVectorLength) {
int totalBits = bitsPerCoordinate * vectorLength;
int byteLength = (totalBits + 7) >> 3;
this.quantizedVector = new byte[byteLength];
this.currentVectorLength = vectorLength;
} else {
Arrays.fill(this.quantizedVector, (byte) 0);
}
}

return this.quantizedVector;
/**
 * Reports whether this output already holds a buffer sized for the given vector length.
 * The result is true only when a quantized vector buffer exists and the length it was
 * last prepared for equals {@code vectorLength}.
 *
 * @param vectorLength The length of the vector to be quantized.
 * @return true if a buffer exists and was prepared for {@code vectorLength}, false otherwise.
 */
@Override
public boolean isPrepared(int vectorLength) {
    final boolean hasBuffer = quantizedVector != null;
    return hasBuffer && currentVectorLength == vectorLength;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,18 @@ public interface QuantizationOutput<T> {
T getQuantizedVector();

/**
* Prepares and returns the writable quantized vector for direct modification.
* Prepares the quantized vector based on the vector length.
* This includes initializing or resetting the quantized vector.
*
* @param params the parameters needed for preparing the quantized vector.
* @return the prepared and writable quantized vector.
* @param vectorLength The length of the vector to be quantized.
*/
T prepareAndGetWritableQuantizedVector(Object... params);
void prepareQuantizedVector(int vectorLength);

/**
* Checks if the quantized vector has already been prepared for the given vector length.
*
* @param vectorLength The length of the vector to be quantized.
* @return true if the quantized vector is already prepared, false otherwise.
*/
boolean isPrepared(int vectorLength);
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,14 @@ public void quantize(final float[] vector, final QuantizationState state, final
throw new IllegalArgumentException("Vector to quantize must not be null.");
}
validateState(state);
int vectorLength = vector.length;
MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) state;
float[][] thresholds = multiBitState.getThresholds();
if (thresholds == null || thresholds[0].length != vector.length) {
throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
}
// Prepare and get the writable array
byte[] writableArray = output.prepareAndGetWritableQuantizedVector(bitsPerCoordinate, vector.length);
BitPacker.quantizeAndPackBits(vector, thresholds, bitsPerCoordinate, writableArray);
if (!output.isPrepared(vectorLength)) output.prepareQuantizedVector(vectorLength);
BitPacker.quantizeAndPackBits(vector, thresholds, bitsPerCoordinate, output.getQuantizedVector());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ public void quantize(final float[] vector, final QuantizationState state, final
throw new IllegalArgumentException("Vector to quantize must not be null.");
}
validateState(state);
int vectorLength = vector.length;
OneBitScalarQuantizationState binaryState = (OneBitScalarQuantizationState) state;
float[] thresholds = binaryState.getMeanThresholds();
if (thresholds == null || thresholds.length != vector.length) {
if (thresholds == null || thresholds.length != vectorLength) {
throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
}
// Prepare and get the writable array
byte[] writableArray = output.prepareAndGetWritableQuantizedVector(1, vector.length);
BitPacker.quantizeAndPackBits(vector, thresholds, writableArray);
if (!output.isPrepared(vectorLength)) output.prepareQuantizedVector(vectorLength);
BitPacker.quantizeAndPackBits(vector, thresholds, output.getQuantizedVector());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ public void testQuantize_twoBit() throws IOException {
ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);

BinaryQuantizationOutput output = new BinaryQuantizationOutput();
BinaryQuantizationOutput output = new BinaryQuantizationOutput(2);
twoBitQuantizer.quantize(vector, state, output);

assertNotNull(output.getQuantizedVector());
}

Expand All @@ -72,14 +73,16 @@ public void testQuantize_fourBit() throws IOException {
ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds);

BinaryQuantizationOutput output = new BinaryQuantizationOutput();
BinaryQuantizationOutput output = new BinaryQuantizationOutput(4);
fourBitQuantizer.quantize(vector, state, output);

assertNotNull(output.getQuantizedVector());
}

public void testQuantize_withNullVector() throws IOException {
MultiBitScalarQuantizer twoBitQuantizer = new MultiBitScalarQuantizer(2);
BinaryQuantizationOutput output = new BinaryQuantizationOutput();
BinaryQuantizationOutput output = new BinaryQuantizationOutput(2);
output.prepareQuantizedVector(8); // Example length
expectThrows(
IllegalArgumentException.class,
() -> twoBitQuantizer.quantize(
Expand All @@ -90,6 +93,123 @@ public void testQuantize_withNullVector() throws IOException {
);
}

/**
 * Quantizes the same values twice into one {@code BinaryQuantizationOutput} and verifies that
 * the backing byte array is reused and that both calls yield the same packed bits.
 */
public void testQuantize_twoBit_multiple_times() throws IOException {
    final MultiBitScalarQuantizer quantizer = new MultiBitScalarQuantizer(2);
    final float[] firstInput = { -2.5f, 1.5f, -0.5f, 4.0f, 6.5f, -3.5f, 0.0f, 7.2f };
    final float[][] thresholds = {
        { -3.0f, 1.0f, -1.0f, 3.5f, 5.0f, -4.0f, 0.5f, 7.0f },
        { -2.0f, 2.0f, 0.0f, 4.5f, 6.0f, -2.5f, -0.5f, 8.0f } };
    final MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(
        new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT),
        thresholds
    );
    final BinaryQuantizationOutput output = new BinaryQuantizationOutput(2);

    // Packed bits expected for the input values above with the given thresholds.
    final byte[] expectedPackedBits = { (byte) 0b11111101, (byte) 0b00001010 };

    // Round one: the output must hold a buffer with the expected contents.
    quantizer.quantize(firstInput, state, output);
    final byte[] bufferAfterFirstCall = output.getQuantizedVector();
    assertNotNull(bufferAfterFirstCall);
    assertArrayEquals(expectedPackedBits, bufferAfterFirstCall);

    // Round two uses a distinct array instance carrying the very same values,
    // so the expected packed bits are unchanged.
    final float[] secondInput = firstInput.clone();
    quantizer.quantize(secondInput, state, output);

    // Same vector length, so the output must hand back the same byte array instance.
    final byte[] bufferAfterSecondCall = output.getQuantizedVector();
    assertSame(bufferAfterFirstCall, bufferAfterSecondCall);
    assertArrayEquals(expectedPackedBits, bufferAfterSecondCall);
}

/**
 * Quantizes one vector three times into the same output and verifies that every call after
 * the first returns the identical byte array instance with the identical packed contents.
 */
public void testQuantize_ReuseByteArray_forMultiBitQuantizer() throws IOException {
    final MultiBitScalarQuantizer quantizer = new MultiBitScalarQuantizer(2);
    final float[] input = { -2.5f, 1.5f, -0.5f, 4.0f, 6.5f, -3.5f, 0.0f, 7.2f };
    final float[][] thresholds = {
        { -3.0f, 1.0f, -1.0f, 3.5f, 5.0f, -4.0f, 0.5f, 7.0f },
        { -2.0f, 2.0f, 0.0f, 4.5f, 6.0f, -2.5f, -0.5f, 8.0f } };
    final MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(
        new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT),
        thresholds
    );
    final BinaryQuantizationOutput output = new BinaryQuantizationOutput(2);

    // Packed bits expected for the input values above with the given thresholds.
    final byte[] expectedPackedBits = { (byte) 0b11111101, (byte) 0b00001010 };

    // Initial quantization establishes the reference buffer.
    quantizer.quantize(input, state, output);
    final byte[] referenceBuffer = output.getQuantizedVector();
    assertArrayEquals(expectedPackedBits, output.getQuantizedVector());

    // Two further quantizations with the same vector length: the buffer instance must be
    // reused each time and must still carry the expected bits.
    for (int round = 2; round <= 3; round++) {
        quantizer.quantize(input, state, output);
        final byte[] currentBuffer = output.getQuantizedVector();
        assertSame(referenceBuffer, currentBuffer);
        assertArrayEquals(expectedPackedBits, output.getQuantizedVector());
    }
}

/**
 * Runs several equal-length vectors through one output object, explicitly preparing the
 * output before each quantization. Whenever the output reports it was already prepared,
 * the explicit prepare call must keep the same byte array instance; after quantization
 * the output must report itself prepared for that length.
 */
public void testQuantize_withMultipleVectors_inLoop() throws IOException {
    final MultiBitScalarQuantizer quantizer = new MultiBitScalarQuantizer(2);
    final float[][] inputs = {
        { -2.5f, 1.5f, -0.5f, 4.0f, 6.5f, -3.5f, 0.0f, 7.2f },
        { 2.0f, -1.0f, 3.5f, 0.0f, 5.5f, -2.5f, 1.5f, 6.0f },
        { -4.0f, 2.0f, -1.5f, 3.5f, -0.5f, 1.0f, 2.5f, -3.0f } };
    final float[][] thresholds = {
        { -3.0f, 1.0f, -1.0f, 3.5f, 5.0f, -4.0f, 0.5f, 7.0f },
        { -2.0f, 2.0f, 0.0f, 4.5f, 6.0f, -2.5f, -0.5f, 8.0f } };
    final MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(
        new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT),
        thresholds
    );
    final BinaryQuantizationOutput output = new BinaryQuantizationOutput(2);

    byte[] lastBuffer = null;
    for (int i = 0; i < inputs.length; i++) {
        final float[] current = inputs[i];
        final int length = current.length;

        // Record the prepared state before touching the output.
        final boolean alreadyPrepared = output.isPrepared(length);

        // Explicit prepare for this vector length.
        output.prepareQuantizedVector(length);

        // An already-prepared output must keep its existing buffer instance.
        if (alreadyPrepared) {
            assertSame(lastBuffer, output.getQuantizedVector());
        }

        quantizer.quantize(current, state, output);

        // Remember the buffer for the identity check in the next iteration.
        lastBuffer = output.getQuantizedVector();

        assertTrue(output.isPrepared(length));
    }
}

// Mock classes for testing
private static class MockTrainingRequest extends TrainingRequest<float[]> {
private final float[][] vectors;
Expand Down
Loading

0 comments on commit af650d4

Please sign in to comment.