Skip to content

Commit

Permalink
Align dimensions to the nearest multiple of 8 in QuantizationState
Browse files Browse the repository at this point in the history
Signed-off-by: VIKASH TIWARI <[email protected]>
  • Loading branch information
Vikasht34 committed Aug 28, 2024
1 parent bf38c2e commit 76cbbbc
Show file tree
Hide file tree
Showing 12 changed files with 92 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import org.apache.lucene.index.Sorter;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.indices.Model;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import lombok.experimental.UtilityClass;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.quantizationService;
package org.opensearch.knn.index.quantizationservice;

import lombok.extern.log4j.Log4j2;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
Expand Down Expand Up @@ -49,7 +49,7 @@ public T getVectorAtThePosition(int position) throws IOException {
}
knnVectorValues.nextDoc();
}
// Return the vector and the updated index
// Return the vector
return knnVectorValues.getVector();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.quantizationService;
package org.opensearch.knn.index.quantizationservice;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,14 @@ public int getDimensions() {
if (thresholds == null || thresholds.length == 0 || thresholds[0] == null) {
throw new IllegalStateException("Error in getting Dimension: The thresholds array is not initialized.");
}
return thresholds.length * thresholds[0].length;
int originalDimensions = thresholds[0].length;

// Align the original dimensions to the next multiple of 8 for each bit level
int alignedDimensions = (originalDimensions + 7) & ~7;

// The final dimension count should consider the bit levels
return thresholds.length * alignedDimensions;

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ public int getBytesPerVector() {
@Override
public int getDimensions() {
// For one-bit quantization, the dimension for indexing is just the length of the thresholds array.
return meanThresholds.length;
// Align the original dimensions to the next multiple of 8
return (meanThresholds.length + 7) & ~7;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.opensearch.knn.index.vectorvalues.TestVectorValues;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.quantizationService.QuantizationService;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.opensearch.knn.index.vectorvalues.TestVectorValues;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.quantizationService;
package org.opensearch.knn.index.quantizationservice;

import org.opensearch.knn.KNNTestCase;
import org.junit.Before;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,78 @@ public void testOneBitScalarQuantizationStateRamBytesUsed() throws IOException {
assertEquals(expectedRamBytesUsed, actualRamBytesUsed);
}

public void testMultiBitScalarQuantizationStateGetDimensionsWithDimensionNotMultipleOf8() {
ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);

// Case 1: 3 thresholds, each with 2 dimensions
float[][] thresholds1 = { { 0.5f, 1.5f }, { 1.0f, 2.0f }, { 1.5f, 2.5f } };
MultiBitScalarQuantizationState state1 = new MultiBitScalarQuantizationState(params, thresholds1);
int expectedDimensions1 = 24; // The next multiple of 8 considering all bits
assertEquals(expectedDimensions1, state1.getDimensions());

// Case 2: 1 threshold, with 5 dimensions (5 bits, should align to 8)
float[][] thresholds2 = { { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f } };
MultiBitScalarQuantizationState state2 = new MultiBitScalarQuantizationState(params, thresholds2);
int expectedDimensions2 = 8; // The next multiple of 8 considering all bits
assertEquals(expectedDimensions2, state2.getDimensions());

// Case 3: 4 thresholds, each with 7 dimensions (28 bits, should align to 32)
float[][] thresholds3 = {
{ 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f },
{ 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f },
{ 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f },
{ 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } };
MultiBitScalarQuantizationState state3 = new MultiBitScalarQuantizationState(params, thresholds3);
int expectedDimensions3 = 32; // The next multiple of 8 considering all bits
assertEquals(expectedDimensions3, state3.getDimensions());

// Case 4: 2 thresholds, each with 8 dimensions (16 bits, already aligned)
float[][] thresholds4 = { { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f }, { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } };
MultiBitScalarQuantizationState state4 = new MultiBitScalarQuantizationState(params, thresholds4);
int expectedDimensions4 = 16; // Already aligned to 8
assertEquals(expectedDimensions4, state4.getDimensions());

// Case 5: 2 thresholds, each with 6 dimensions (12 bits, should align to 16)
float[][] thresholds5 = { { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f }, { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f } };
MultiBitScalarQuantizationState state5 = new MultiBitScalarQuantizationState(params, thresholds5);
int expectedDimensions5 = 16; // The next multiple of 8 considering all bits
assertEquals(expectedDimensions5, state5.getDimensions());
}

public void testOneBitScalarQuantizationStateGetDimensionsWithDimensionNotMultipleOf8() {
ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);

// Case 1: 5 dimensions (should align to 8)
float[] thresholds1 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f };
OneBitScalarQuantizationState state1 = new OneBitScalarQuantizationState(params, thresholds1);
int expectedDimensions1 = 8; // The next multiple of 8
assertEquals(expectedDimensions1, state1.getDimensions());

// Case 2: 7 dimensions (should align to 8)
float[] thresholds2 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f };
OneBitScalarQuantizationState state2 = new OneBitScalarQuantizationState(params, thresholds2);
int expectedDimensions2 = 8; // The next multiple of 8
assertEquals(expectedDimensions2, state2.getDimensions());

// Case 3: 8 dimensions (already aligned to 8)
float[] thresholds3 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f };
OneBitScalarQuantizationState state3 = new OneBitScalarQuantizationState(params, thresholds3);
int expectedDimensions3 = 8; // Already aligned to 8
assertEquals(expectedDimensions3, state3.getDimensions());

// Case 4: 10 dimensions (should align to 16)
float[] thresholds4 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f, 9.5f };
OneBitScalarQuantizationState state4 = new OneBitScalarQuantizationState(params, thresholds4);
int expectedDimensions4 = 16; // The next multiple of 8
assertEquals(expectedDimensions4, state4.getDimensions());

// Case 5: 16 dimensions (already aligned to 16)
float[] thresholds5 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f, 9.5f, 10.5f, 11.5f, 12.5f, 13.5f, 14.5f, 15.5f };
OneBitScalarQuantizationState state5 = new OneBitScalarQuantizationState(params, thresholds5);
int expectedDimensions5 = 16; // Already aligned to 16
assertEquals(expectedDimensions5, state5.getDimensions());
}

public void testMultiBitScalarQuantizationStateRamBytesUsedManualCalculation() throws IOException {
ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
float[][] thresholds = { { 0.5f, 1.5f, 2.5f }, { 1.0f, 2.0f, 3.0f } };
Expand Down

0 comments on commit 76cbbbc

Please sign in to comment.