Skip to content

Commit

Permalink
Add support for Lucene SQ 4 bits
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Nov 7, 2024
1 parent 64bae92 commit 7580229
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@
import org.opensearch.knn.index.mapper.Mode;

import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.index.engine.lucene.LuceneHNSWMethod.HNSW_METHOD_COMPONENT;
Expand Down Expand Up @@ -60,6 +64,7 @@ public ResolvedMethodContext resolveMethod(

protected void resolveEncoder(KNNMethodContext resolvedKNNMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
if (shouldEncoderBeResolved(resolvedKNNMethodContext, knnMethodConfigContext) == false) {
validateEncoderDimension(resolvedKNNMethodContext, knnMethodConfigContext);
return;
}

Expand Down Expand Up @@ -94,6 +99,30 @@ private void validateConfig(KNNMethodConfigContext knnMethodConfigContext, boole
}
}

private void validateEncoderDimension(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext) {
String encoderName = getEncoderName(knnMethodContext);
if (encoderName == null || ENCODER_SQ.equals(encoderName) == false) {
return;
}

MethodComponentContext encoderMethodComponentContext = getEncoderComponentContext(knnMethodContext);
if (encoderMethodComponentContext.getParameters().containsKey(LUCENE_SQ_BITS)
&& encoderMethodComponentContext.getParameters().get(LUCENE_SQ_BITS).equals(4)
&& knnMethodConfigContext.getDimension() % 2 != 0) {
ValidationException validationException = new ValidationException();
validationException.addValidationError(
String.format(
Locale.ROOT,
"Odd vector dimension is not supported when [%s] is set to [4] for [%s] engine with [%s] encoder",
LUCENE_SQ_BITS,
LUCENE_NAME,
ENCODER_SQ
)
);
throw validationException;
}
}

private CompressionLevel getDefaultCompressionLevel(KNNMethodConfigContext knnMethodConfigContext) {
if (CompressionLevel.isConfigured(knnMethodConfigContext.getCompressionLevel())) {
return knnMethodConfigContext.getCompressionLevel();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.engine.lucene;

import com.google.common.collect.ImmutableSet;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.Encoder;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
Expand All @@ -31,7 +32,7 @@
public class LuceneSQEncoder implements Encoder {
private static final Set<VectorDataType> SUPPORTED_DATA_TYPES = ImmutableSet.of(VectorDataType.FLOAT);

private final static List<Integer> LUCENE_SQ_BITS_SUPPORTED = List.of(7);
private final static List<Integer> LUCENE_SQ_BITS_SUPPORTED = List.of(4, 7);
private final static MethodComponent METHOD_COMPONENT = MethodComponent.Builder.builder(ENCODER_SQ)
.addSupportedDataTypes(SUPPORTED_DATA_TYPES)
.addParameter(
Expand All @@ -58,7 +59,24 @@ public CompressionLevel calculateCompressionLevel(
MethodComponentContext methodComponentContext,
KNNMethodConfigContext knnMethodConfigContext
) {
// Hard coding to 4x for now, given thats all that is supported.
if (methodComponentContext.getParameters().containsKey(LUCENE_SQ_BITS) == false) {
return CompressionLevel.x4;
}

// Map the number of bits passed in, back to the compression level
Object value = methodComponentContext.getParameters().get(LUCENE_SQ_BITS);
ValidationException validationException = METHOD_COMPONENT.getParameters()
.get(LUCENE_SQ_BITS)
.validate(value, knnMethodConfigContext);
if (validationException != null) {
throw validationException;
}

Integer bitCount = (Integer) value;
if (bitCount == 4) {
return CompressionLevel.NOT_CONFIGURED;
}
// Return 4x compression for 7 bits
return CompressionLevel.x4;
}
}
60 changes: 36 additions & 24 deletions src/test/java/org/opensearch/knn/index/LuceneEngineIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ public class LuceneEngineIT extends KNNRestTestCase {
private static final String INTEGER_FIELD_NAME = "int_field";
private static final String FILED_TYPE_INTEGER = "integer";
private static final String NON_EXISTENT_INTEGER_FIELD_NAME = "nonexistent_int_field";
private final static List<Integer> LUCENE_SQ_BITS_SUPPORTED = ImmutableList.of(4, 7);
private static final int DIMENSION_SQ = 2;

@After
public final void cleanUp() throws IOException {
Expand Down Expand Up @@ -592,16 +594,30 @@ public void testSQ_withInvalidParams_thenThrowException() {
);
}

@SneakyThrows
public void testSQ_4bits_withOddDimension_thenThrowException() {
expectThrows(
ResponseException.class,
() -> createKnnIndexMappingWithLuceneEngineAndSQEncoder(
DIMENSION,
SpaceType.L2,
VectorDataType.FLOAT,
4,
MINIMUM_CONFIDENCE_INTERVAL
)
);
}

@SneakyThrows
public void testAddDocWithSQEncoder() {
createKnnIndexMappingWithLuceneEngineAndSQEncoder(
DIMENSION,
DIMENSION_SQ,
SpaceType.L2,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
LUCENE_SQ_BITS_SUPPORTED.get(random().nextInt(LUCENE_SQ_BITS_SUPPORTED.size())),
MAXIMUM_CONFIDENCE_INTERVAL
);
Float[] vector = new Float[] { 2.0f, 4.5f, 6.5f };
Float[] vector = new Float[] { 2.0f, 4.5f };
addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);

refreshIndex(INDEX_NAME);
Expand All @@ -611,16 +627,16 @@ public void testAddDocWithSQEncoder() {
@SneakyThrows
public void testUpdateDocWithSQEncoder() {
createKnnIndexMappingWithLuceneEngineAndSQEncoder(
DIMENSION,
DIMENSION_SQ,
SpaceType.INNER_PRODUCT,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
LUCENE_SQ_BITS_SUPPORTED.get(random().nextInt(LUCENE_SQ_BITS_SUPPORTED.size())),
MAXIMUM_CONFIDENCE_INTERVAL
);
Float[] vector = { 6.0f, 6.0f, 7.0f };
Float[] vector = { 6.0f, 6.0f };
addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);

Float[] updatedVector = { 8.0f, 8.0f, 8.0f };
Float[] updatedVector = { 8.0f, 8.0f };
updateKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, updatedVector);

refreshIndex(INDEX_NAME);
Expand All @@ -630,13 +646,13 @@ public void testUpdateDocWithSQEncoder() {
@SneakyThrows
public void testDeleteDocWithSQEncoder() {
createKnnIndexMappingWithLuceneEngineAndSQEncoder(
DIMENSION,
DIMENSION_SQ,
SpaceType.INNER_PRODUCT,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
LUCENE_SQ_BITS_SUPPORTED.get(random().nextInt(LUCENE_SQ_BITS_SUPPORTED.size())),
MAXIMUM_CONFIDENCE_INTERVAL
);
Float[] vector = { 6.0f, 6.0f, 7.0f };
Float[] vector = { 6.0f, 6.0f };
addKnnDoc(INDEX_NAME, DOC_ID, FIELD_NAME, vector);

deleteKnnDoc(INDEX_NAME, DOC_ID);
Expand All @@ -648,16 +664,16 @@ public void testDeleteDocWithSQEncoder() {
@SneakyThrows
public void testIndexingAndQueryingWithSQEncoder() {
createKnnIndexMappingWithLuceneEngineAndSQEncoder(
DIMENSION,
DIMENSION_SQ,
SpaceType.INNER_PRODUCT,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
LUCENE_SQ_BITS_SUPPORTED.get(random().nextInt(LUCENE_SQ_BITS_SUPPORTED.size())),
MAXIMUM_CONFIDENCE_INTERVAL
);

int numDocs = 10;
for (int i = 0; i < numDocs; i++) {
float[] indexVector = new float[DIMENSION];
float[] indexVector = new float[DIMENSION_SQ];
Arrays.fill(indexVector, (float) i);
addKnnDocWithAttributes(INDEX_NAME, Integer.toString(i), FIELD_NAME, indexVector, ImmutableMap.of("rating", String.valueOf(i)));
}
Expand All @@ -666,7 +682,7 @@ public void testIndexingAndQueryingWithSQEncoder() {
refreshAllNonSystemIndices();
assertEquals(numDocs, getDocCount(INDEX_NAME));

float[] queryVector = new float[DIMENSION];
float[] queryVector = new float[DIMENSION_SQ];
Arrays.fill(queryVector, (float) numDocs);
int k = 10;

Expand All @@ -680,24 +696,20 @@ public void testIndexingAndQueryingWithSQEncoder() {

public void testQueryWithFilterUsingSQEncoder() throws Exception {
createKnnIndexMappingWithLuceneEngineAndSQEncoder(
DIMENSION,
DIMENSION_SQ,
SpaceType.INNER_PRODUCT,
VectorDataType.FLOAT,
LUCENE_SQ_DEFAULT_BITS,
LUCENE_SQ_BITS_SUPPORTED.get(random().nextInt(LUCENE_SQ_BITS_SUPPORTED.size())),
MAXIMUM_CONFIDENCE_INTERVAL
);

addKnnDocWithAttributes(
DOC_ID,
new float[] { 6.0f, 7.9f, 3.1f },
ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet")
);
addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f, 4.8f }, ImmutableMap.of(COLOR_FIELD_NAME, "green"));
addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f, 7.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "red"));
addKnnDocWithAttributes(DOC_ID, new float[] { 6.0f, 7.9f }, ImmutableMap.of(COLOR_FIELD_NAME, "red", TASTE_FIELD_NAME, "sweet"));
addKnnDocWithAttributes(DOC_ID_2, new float[] { 3.2f, 2.1f }, ImmutableMap.of(COLOR_FIELD_NAME, "green"));
addKnnDocWithAttributes(DOC_ID_3, new float[] { 4.1f, 5.0f }, ImmutableMap.of(COLOR_FIELD_NAME, "red"));

refreshIndex(INDEX_NAME);

final float[] searchVector = { 6.0f, 6.0f, 4.1f };
final float[] searchVector = { 6.0f, 6.0f };
List<String> expectedDocIdsKGreaterThanFilterResult = Arrays.asList(DOC_ID, DOC_ID_3);
List<String> expectedDocIdsKLimitsFilterResult = Arrays.asList(DOC_ID);
validateQueryResultsWithFilters(searchVector, 5, 1, expectedDocIdsKGreaterThanFilterResult, expectedDocIdsKLimitsFilterResult);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,22 @@
package org.opensearch.knn.index.engine.lucene;

import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.mapper.CompressionLevel;

import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.LUCENE_SQ_BITS;

public class LuceneSQEncoderTests extends KNNTestCase {
public void testCalculateCompressionLevel() {
LuceneSQEncoder encoder = new LuceneSQEncoder();
assertEquals(CompressionLevel.x4, encoder.calculateCompressionLevel(null, null));
assertEquals(CompressionLevel.NOT_CONFIGURED, encoder.calculateCompressionLevel(generateMethodComponentContext(4), null));
assertEquals(CompressionLevel.x4, encoder.calculateCompressionLevel(generateMethodComponentContext(7), null));
}

private MethodComponentContext generateMethodComponentContext(int bitCount) {
return new MethodComponentContext(ENCODER_SQ, Map.of(LUCENE_SQ_BITS, bitCount));
}
}

0 comments on commit 7580229

Please sign in to comment.