Skip to content

Commit

Permalink
Block PQ support when data type is binary format (#1868)
Browse files Browse the repository at this point in the history
  • Loading branch information
junqiu-lei authored Jul 23, 2024
1 parent aa5312e commit 87db66a
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 131 deletions.
39 changes: 33 additions & 6 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ public static ValidationException validateKnnField(
String field,
int expectedDimension,
ModelDao modelDao,
VectorDataType expectedVectorDataType
VectorDataType trainRequestVectorDataType,
KNNMethodContext trainRequestKnnMethodContext
) {
// Index metadata should not be null
if (indexMetadata == null) {
Expand Down Expand Up @@ -142,27 +143,53 @@ public static ValidationException validateKnnField(
return exception;
}

if (expectedVectorDataType != null) {
if (VectorDataType.BYTE == expectedVectorDataType) {
if (trainRequestVectorDataType != null) {
if (VectorDataType.BYTE == trainRequestVectorDataType) {
exception.addValidationError(
String.format(Locale.ROOT, "vector data type \"%s\" is not supported for training.", expectedVectorDataType.getValue())
String.format(
Locale.ROOT,
"vector data type \"%s\" is not supported for training.",
trainRequestVectorDataType.getValue()
)
);
return exception;
}
VectorDataType trainIndexDataType = getVectorDataTypeFromFieldMapping(fieldMap);

if (trainIndexDataType != expectedVectorDataType) {
if (trainIndexDataType != trainRequestVectorDataType) {
exception.addValidationError(
String.format(
Locale.ROOT,
"Field \"%s\" has data type %s, which is different from data type used in the training request: %s",
field,
trainIndexDataType.getValue(),
expectedVectorDataType.getValue()
trainRequestVectorDataType.getValue()
)
);
return exception;
}

// Block binary vector data type for pq encoder
if (trainRequestKnnMethodContext != null) {
MethodComponentContext methodComponentContext = trainRequestKnnMethodContext.getMethodComponentContext();
Map<String, Object> parameters = methodComponentContext.getParameters();

if (parameters != null && parameters.containsKey(KNNConstants.METHOD_ENCODER_PARAMETER)) {
MethodComponentContext encoder = (MethodComponentContext) parameters.get(KNNConstants.METHOD_ENCODER_PARAMETER);
if (encoder != null
&& KNNConstants.ENCODER_PQ.equals(encoder.getName())
&& VectorDataType.BINARY == trainRequestVectorDataType) {
exception.addValidationError(
String.format(
Locale.ROOT,
"vector data type \"%s\" is not supported for pq encoder.",
trainRequestVectorDataType.getValue()
)
);
return exception;
}
}
}
}

// Return if dimension does not need to be checked
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,8 @@ public ActionRequestValidationException validate() {
this.trainingField,
this.dimension,
modelDao,
this.vectorDataType
vectorDataType,
knnMethodContext
);
if (fieldValidation != null) {
exception = exception == null ? new ActionRequestValidationException() : exception;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public <T> void read(
throw validationException;
}

ValidationException fieldValidationException = IndexUtil.validateKnnField(indexMetadata, fieldName, -1, null, null);
ValidationException fieldValidationException = IndexUtil.validateKnnField(indexMetadata, fieldName, -1, null, null, null);
if (fieldValidationException != null) {
validationException = validationException == null ? new ValidationException() : validationException;
// Copy the field-level validation errors into the aggregate exception. The previous code passed the
// aggregate's own error list back into itself, so the errors reported for the field were never surfaced.
validationException.addValidationErrors(fieldValidationException.validationErrors());
Expand Down
114 changes: 0 additions & 114 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -1712,120 +1712,6 @@ public void testIVF_whenBinaryFormat_whenIVF_thenSuccess() {
validateGraphEviction();
}

/**
 * End-to-end scenario: ingests random binary vectors into a training index, trains a FAISS IVF model
 * that uses a PQ encoder on them, creates an index backed by that model, indexes four documents and
 * checks that a k-NN query returns exactly {@code k} results. All created resources are removed at the end.
 */
@SneakyThrows
public void testIVF_whenBinaryFormat_whenIVFPQ_thenSuccess() {
    final String modelId = "test-model-ivfpq-binary";
    final int dimension = 8;

    final String trainingIndexName = "train-index-ivfpq-binary";
    final String trainingFieldName = "train-field-ivfpq-binary";

    // Mapping of the index that supplies the training vectors: binary knn_vector field, FAISS HNSW, hamming space.
    XContentBuilder trainIndexMappingBuilder = XContentFactory.jsonBuilder()
        .startObject()
        .startObject("properties")
        .startObject(trainingFieldName)
        .field("type", "knn_vector")
        .field("dimension", dimension)
        .field("data_type", VectorDataType.BINARY.getValue())
        .startObject(KNN_METHOD)
        .field(NAME, KNNEngine.FAISS.getMethod(METHOD_HNSW).getMethodComponent().getName())
        .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.HAMMING.getValue())
        .field(KNN_ENGINE, KNNEngine.FAISS.getName())
        .startObject(PARAMETERS)
        .field(METHOD_PARAMETER_M, 24)
        .field(METHOD_PARAMETER_EF_CONSTRUCTION, 128)
        .endObject()
        .endObject()
        .endObject()
        .endObject()
        .endObject();
    createKnnIndex(trainingIndexName, trainIndexMappingBuilder.toString());

    final int trainingDataCount = 50;
    bulkIngestRandomBinaryVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

    // Training request: IVF method with a nested PQ encoder, binary data type, hamming space.
    XContentBuilder trainRequestBody = XContentFactory.jsonBuilder()
        .startObject()
        .field(TRAIN_INDEX_PARAMETER, trainingIndexName)
        .field(TRAIN_FIELD_PARAMETER, trainingFieldName)
        .field(DIMENSION, dimension)
        .field(MODEL_DESCRIPTION, "My model description")
        .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue())
        .startObject(KNN_METHOD)
        .field(NAME, METHOD_IVF)
        .field(KNN_ENGINE, FAISS_NAME)
        .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.HAMMING.getValue())
        .startObject(PARAMETERS)
        .field(METHOD_PARAMETER_NPROBES, 1)
        .field(METHOD_PARAMETER_NLIST, 1)
        .startObject(METHOD_ENCODER_PARAMETER)
        .field(NAME, ENCODER_PQ)
        .startObject(PARAMETERS)
        .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 8)
        .field(ENCODER_PARAMETER_PQ_M, 8)
        .endObject()
        .endObject()
        .endObject()
        .endObject()
        .endObject();

    trainModel(modelId, trainRequestBody);

    // Wait for the model to report that training succeeded.
    assertTrainingSucceeds(modelId, 30, 1000);

    // Create the searchable index whose knn_vector field is defined purely by the trained model.
    final String fieldName = "test-field-name-ivfpq-binary";
    final String indexName = "test-index-name-ivfpq-binary";

    XContentBuilder modelIndexMappingBuilder = XContentFactory.jsonBuilder()
        .startObject()
        .startObject("properties")
        .startObject(fieldName)
        .field("type", "knn_vector")
        .field(MODEL_ID, modelId)
        .endObject()
        .endObject()
        .endObject();
    createKnnIndex(indexName, getKNNDefaultIndexSettings(), modelIndexMappingBuilder.toString());

    // Documents "1".."4" carry the single-element vectors {11}, {22}, {33} and {44} respectively.
    for (int docNumber = 1; docNumber <= 4; docNumber++) {
        addKnnDoc(indexName, String.valueOf(docNumber), fieldName, new Integer[] { 11 * docNumber });
    }

    final Integer[] queryVector = { 15 };
    final int k = 2;

    XContentBuilder searchBody = XContentFactory.jsonBuilder()
        .startObject()
        .startObject("query")
        .startObject("knn")
        .startObject(fieldName)
        .field("vector", queryVector)
        .field("k", k)
        .endObject()
        .endObject()
        .endObject()
        .endObject();
    Response searchResponse = searchKNNIndex(indexName, searchBody, k);
    String searchResponseBody = EntityUtils.toString(searchResponse.getEntity());
    List<KNNResult> hits = parseSearchResponse(searchResponseBody, fieldName);
    assertEquals(k, hits.size());

    // Cleanup. NOTE(review): the 45 second pause presumably lets the cluster release the model-backed
    // index before the model itself is deleted — confirm.
    deleteKNNIndex(indexName);
    Thread.sleep(45 * 1000);
    deleteModel(modelId);
    deleteKNNIndex(trainingIndexName);
    validateGraphEviction();
}

protected void setupKNNIndexForFilterQuery() throws Exception {
// Create Mappings
XContentBuilder builder = XContentFactory.jsonBuilder()
Expand Down
54 changes: 45 additions & 9 deletions src/test/java/org/opensearch/knn/index/IndexUtilTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.jni.JNIService;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
Expand All @@ -38,7 +39,10 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ;
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.IndexUtil.getParametersAtLoading;
Expand Down Expand Up @@ -117,7 +121,7 @@ public void testValidateKnnField_NestedField() {
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null);
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null);

assertNull(e);
}
Expand All @@ -138,7 +142,7 @@ public void testValidateKnnField_NonNestedField() {
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null);
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null);

assertNull(e);
}
Expand All @@ -158,7 +162,7 @@ public void testValidateKnnField_NonKnnField() {
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null);
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null);

assert Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Field \"" + field + "\" is not of type knn_vector.;");
}
Expand All @@ -182,7 +186,7 @@ public void testValidateKnnField_WrongFieldPath() {
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null);
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null);

assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Field \"" + field + "\" does not exist.;"));
}
Expand All @@ -206,7 +210,7 @@ public void testValidateKnnField_EmptyField() {
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null);
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null);

System.out.println(Objects.requireNonNull(e).getMessage());

Expand All @@ -223,7 +227,7 @@ public void testValidateKnnField_EmptyIndexMetadata() {
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null);
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null);

assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Invalid index. Index does not contain a mapping;"));
}
Expand Down Expand Up @@ -273,7 +277,7 @@ public void testValidateKnnField_whenTrainModelUseDifferentVectorDataTypeFromTra
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
ModelDao modelDao = mock(ModelDao.class);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, VectorDataType.BINARY);
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, VectorDataType.BINARY, null);
System.out.println(Objects.requireNonNull(e).getMessage());

assert Objects.requireNonNull(e)
Expand All @@ -298,8 +302,7 @@ public void testValidateKnnField_whenPassByteVectorDataType_thenThrowException()
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
ModelDao modelDao = mock(ModelDao.class);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, VectorDataType.BYTE);
System.out.println(Objects.requireNonNull(e).getMessage());
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, VectorDataType.BYTE, null);

assert Objects.requireNonNull(e)
.getMessage()
Expand All @@ -311,4 +314,37 @@ public void testUpdateVectorDataTypeToParameters_whenVectorDataTypeIsBinary() {
IndexUtil.updateVectorDataTypeToParameters(indexParams, VectorDataType.BINARY);
assertEquals(VectorDataType.BINARY.getValue(), indexParams.get(VECTOR_DATA_TYPE_FIELD));
}

/**
 * Verifies that {@code IndexUtil.validateKnnField} returns a validation error when the training request
 * combines the binary vector data type with a PQ encoder in its method context.
 */
public void testValidateKnnField_whenPassBinaryVectorDataTypeAndPQEncoder_thenThrowException() {
    // NOTE(review): the "encoder" entry in the field mapping looks unused by the check under test, which
    // reads the encoder from the KNNMethodContext passed in below — confirm before relying on it.
    Map<String, Object> fieldValues = Map.of("type", "knn_vector", "dimension", 8, "data_type", "binary", "encoder", "pq");
    Map<String, Object> top_level_field = Map.of("top_level_field", fieldValues);
    Map<String, Object> properties = Map.of("properties", top_level_field);
    String field = "top_level_field";
    int dimension = 8;

    MappingMetadata mappingMetadata = mock(MappingMetadata.class);
    when(mappingMetadata.getSourceAsMap()).thenReturn(properties);
    IndexMetadata indexMetadata = mock(IndexMetadata.class);
    when(indexMetadata.mapping()).thenReturn(mappingMetadata);
    ModelDao modelDao = mock(ModelDao.class);

    // IVF method whose parameters carry a PQ encoder component.
    MethodComponentContext pq = new MethodComponentContext(ENCODER_PQ, Collections.emptyMap());
    KNNMethodContext knnMethodContext = new KNNMethodContext(
        KNNEngine.FAISS,
        SpaceType.INNER_PRODUCT,
        new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_ENCODER_PARAMETER, pq))
    );

    ValidationException e = IndexUtil.validateKnnField(
        indexMetadata,
        field,
        dimension,
        modelDao,
        VectorDataType.BINARY,
        knnMethodContext
    );

    // Run the null check outside the assert statement so a missing validation error fails the test even
    // when JVM assertions are disabled.
    String actualMessage = Objects.requireNonNull(e, "expected a validation error for binary + pq").getMessage();

    // Compare literally: with String.matches the '.' characters in the expected text act as regex wildcards.
    String expectedMessage = "Validation Failed: 1: vector data type \"binary\" is not supported for pq encoder.;";
    assert expectedMessage.equals(actualMessage) : "unexpected validation message: " + actualMessage;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ public void testValidation_invalid_invalidMethodContext() {
when(knnMethodContext.validate()).thenReturn(validationException);

when(knnMethodContext.isTrainingRequired()).thenReturn(false);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";
Expand Down Expand Up @@ -454,6 +455,7 @@ public void testValidation_invalid_dimensionDoesNotMatch() {
when(knnMethodContext.validate()).thenReturn(null);

when(knnMethodContext.isTrainingRequired()).thenReturn(true);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";
Expand Down Expand Up @@ -511,6 +513,7 @@ public void testValidation_invalid_preferredNodeDoesNotExist() {
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.validate()).thenReturn(null);
when(knnMethodContext.isTrainingRequired()).thenReturn(true);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";
Expand Down Expand Up @@ -573,6 +576,7 @@ public void testValidation_invalid_descriptionToLong() {
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.validate()).thenReturn(null);
when(knnMethodContext.isTrainingRequired()).thenReturn(true);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";
Expand Down Expand Up @@ -623,6 +627,7 @@ public void testValidation_valid_trainingIndexBuiltFromMethod() {
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.validate()).thenReturn(null);
when(knnMethodContext.isTrainingRequired()).thenReturn(true);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";
Expand Down Expand Up @@ -660,6 +665,7 @@ public void testValidation_valid_trainingIndexBuiltFromModel() {
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.validate()).thenReturn(null);
when(knnMethodContext.isTrainingRequired()).thenReturn(true);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";
Expand Down

0 comments on commit 87db66a

Please sign in to comment.