Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Block PQ support when data type is binary format #1868

Merged
merged 1 commit into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ public static ValidationException validateKnnField(
String field,
int expectedDimension,
ModelDao modelDao,
VectorDataType expectedVectorDataType
VectorDataType trainRequestVectorDataType,
KNNMethodContext trainRequestKnnMethodContext
) {
// Index metadata should not be null
if (indexMetadata == null) {
Expand Down Expand Up @@ -142,27 +143,53 @@ public static ValidationException validateKnnField(
return exception;
}

if (expectedVectorDataType != null) {
if (VectorDataType.BYTE == expectedVectorDataType) {
if (trainRequestVectorDataType != null) {
if (VectorDataType.BYTE == trainRequestVectorDataType) {
exception.addValidationError(
String.format(Locale.ROOT, "vector data type \"%s\" is not supported for training.", expectedVectorDataType.getValue())
String.format(
Locale.ROOT,
"vector data type \"%s\" is not supported for training.",
trainRequestVectorDataType.getValue()
)
);
return exception;
}
VectorDataType trainIndexDataType = getVectorDataTypeFromFieldMapping(fieldMap);

if (trainIndexDataType != expectedVectorDataType) {
if (trainIndexDataType != trainRequestVectorDataType) {
exception.addValidationError(
String.format(
Locale.ROOT,
"Field \"%s\" has data type %s, which is different from data type used in the training request: %s",
field,
trainIndexDataType.getValue(),
expectedVectorDataType.getValue()
trainRequestVectorDataType.getValue()
)
);
return exception;
}

// Block binary vector data type for pq encoder
if (trainRequestKnnMethodContext != null) {
MethodComponentContext methodComponentContext = trainRequestKnnMethodContext.getMethodComponentContext();
Map<String, Object> parameters = methodComponentContext.getParameters();

if (parameters != null && parameters.containsKey(KNNConstants.METHOD_ENCODER_PARAMETER)) {
MethodComponentContext encoder = (MethodComponentContext) parameters.get(KNNConstants.METHOD_ENCODER_PARAMETER);
if (encoder != null
&& KNNConstants.ENCODER_PQ.equals(encoder.getName())
&& VectorDataType.BINARY == trainRequestVectorDataType) {
exception.addValidationError(
String.format(
Locale.ROOT,
"vector data type \"%s\" is not supported for pq encoder.",
trainRequestVectorDataType.getValue()
)
);
return exception;
}
}
}
}

// Return if dimension does not need to be checked
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,8 @@ public ActionRequestValidationException validate() {
this.trainingField,
this.dimension,
modelDao,
this.vectorDataType
vectorDataType,
knnMethodContext
);
if (fieldValidation != null) {
exception = exception == null ? new ActionRequestValidationException() : exception;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public <T> void read(
throw validationException;
}

ValidationException fieldValidationException = IndexUtil.validateKnnField(indexMetadata, fieldName, -1, null, null);
ValidationException fieldValidationException = IndexUtil.validateKnnField(indexMetadata, fieldName, -1, null, null, null);
if (fieldValidationException != null) {
validationException = validationException == null ? new ValidationException() : validationException;
validationException.addValidationErrors(validationException.validationErrors());
Expand Down
114 changes: 0 additions & 114 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -1712,120 +1712,6 @@ public void testIVF_whenBinaryFormat_whenIVF_thenSuccess() {
validateGraphEviction();
}

/**
 * End-to-end test of a model-based index over binary vectors: creates a training index, trains an IVF
 * model whose parameters include a PQ encoder, creates an index from the trained model, indexes four
 * documents, runs a k-NN query and checks that exactly {@code k} results come back, then cleans up.
 *
 * NOTE(review): the validation shown in this change set for IndexUtil.validateKnnField adds the error
 * "vector data type \"binary\" is not supported for pq encoder." for a binary + PQ training request, so
 * this success scenario presumably no longer applies once that validation is active - confirm.
 */
@SneakyThrows
public void testIVF_whenBinaryFormat_whenIVFPQ_thenSuccess() {
String modelId = "test-model-ivfpq-binary";
int dimension = 8;

String trainingIndexName = "train-index-ivfpq-binary";
String trainingFieldName = "train-field-ivfpq-binary";

// Mapping of the training index: one binary knn_vector field with a Faiss HNSW method and Hamming space
String trainIndexMapping = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(trainingFieldName)
.field("type", "knn_vector")
.field("dimension", dimension)
.field("data_type", VectorDataType.BINARY.getValue())
.startObject(KNN_METHOD)
.field(NAME, KNNEngine.FAISS.getMethod(METHOD_HNSW).getMethodComponent().getName())
.field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.HAMMING.getValue())
.field(KNN_ENGINE, KNNEngine.FAISS.getName())
.startObject(PARAMETERS)
.field(METHOD_PARAMETER_M, 24)
.field(METHOD_PARAMETER_EF_CONSTRUCTION, 128)
.endObject()
.endObject()
.endObject()
.endObject()
.endObject()
.toString();

createKnnIndex(trainingIndexName, trainIndexMapping);

// Fill the training index with random binary vectors to train on
int trainingDataCount = 50;
bulkIngestRandomBinaryVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

// Training request body: binary data type, IVF method on Faiss with Hamming space, and a PQ encoder
// nested inside the method parameters
XContentBuilder trainModelXContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.field(TRAIN_INDEX_PARAMETER, trainingIndexName)
.field(TRAIN_FIELD_PARAMETER, trainingFieldName)
.field(DIMENSION, dimension)
.field(MODEL_DESCRIPTION, "My model description")
.field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue())
.startObject(KNN_METHOD)
.field(NAME, METHOD_IVF)
.field(KNN_ENGINE, FAISS_NAME)
.field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.HAMMING.getValue())
.startObject(PARAMETERS)
.field(METHOD_PARAMETER_NPROBES, 1)
.field(METHOD_PARAMETER_NLIST, 1)
.startObject(METHOD_ENCODER_PARAMETER)
.field(NAME, ENCODER_PQ)
.startObject(PARAMETERS)
.field(ENCODER_PARAMETER_PQ_CODE_SIZE, 8)
.field(ENCODER_PARAMETER_PQ_M, 8)
.endObject()
.endObject()
.endObject()
.endObject()
.endObject();

trainModel(modelId, trainModelXContentBuilder);

// Make sure training succeeds after 30 seconds
assertTrainingSucceeds(modelId, 30, 1000);

// Create knn index from model
String fieldName = "test-field-name-ivfpq-binary";
String indexName = "test-index-name-ivfpq-binary";

// The target field only references the model id; no dimension or method is given in this mapping
String indexMapping = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(fieldName)
.field("type", "knn_vector")
.field(MODEL_ID, modelId)
.endObject()
.endObject()
.endObject()
.toString();

createKnnIndex(indexName, getKNNDefaultIndexSettings(), indexMapping);
// Each document vector is a single-element array.
// NOTE(review): presumably one byte packs the 8 binary dimensions - confirm.
Integer[] vector1 = { 11 };
Integer[] vector2 = { 22 };
Integer[] vector3 = { 33 };
Integer[] vector4 = { 44 };
addKnnDoc(indexName, "1", fieldName, vector1);
addKnnDoc(indexName, "2", fieldName, vector2);
addKnnDoc(indexName, "3", fieldName, vector3);
addKnnDoc(indexName, "4", fieldName, vector4);

Integer[] queryVector = { 15 };
int k = 2;

// k-NN query against the model-backed field
XContentBuilder queryBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject("query")
.startObject("knn")
.startObject(fieldName)
.field("vector", queryVector)
.field("k", k)
.endObject()
.endObject()
.endObject()
.endObject();
Response searchResponse = searchKNNIndex(indexName, queryBuilder, k);
List<KNNResult> results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName);
// Only the number of hits is checked, not which documents were returned
assertEquals(k, results.size());

// Cleanup: index first, then the model, then the training index
deleteKNNIndex(indexName);
// NOTE(review): fixed 45 s pause before deleting the model; the reason is not visible here -
// presumably it waits for resources of the deleted index to be released. Confirm.
Thread.sleep(45 * 1000);
deleteModel(modelId);
deleteKNNIndex(trainingIndexName);
validateGraphEviction();
}

protected void setupKNNIndexForFilterQuery() throws Exception {
// Create Mappings
XContentBuilder builder = XContentFactory.jsonBuilder()
Expand Down
54 changes: 45 additions & 9 deletions src/test/java/org/opensearch/knn/index/IndexUtilTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.jni.JNIService;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
Expand All @@ -38,7 +39,10 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ;
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.IndexUtil.getParametersAtLoading;
Expand Down Expand Up @@ -117,7 +121,7 @@ public void testValidateKnnField_NestedField() {
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null);
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null);

assertNull(e);
}
Expand All @@ -138,7 +142,7 @@ public void testValidateKnnField_NonNestedField() {
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null);
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null);

assertNull(e);
}
Expand All @@ -158,7 +162,7 @@ public void testValidateKnnField_NonKnnField() {
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null);
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null);

assert Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Field \"" + field + "\" is not of type knn_vector.;");
}
Expand All @@ -182,7 +186,7 @@ public void testValidateKnnField_WrongFieldPath() {
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null);
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null);

assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Field \"" + field + "\" does not exist.;"));
}
Expand All @@ -206,7 +210,7 @@ public void testValidateKnnField_EmptyField() {
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null);
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null);

System.out.println(Objects.requireNonNull(e).getMessage());

Expand All @@ -223,7 +227,7 @@ public void testValidateKnnField_EmptyIndexMetadata() {
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null);
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null);

assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Invalid index. Index does not contain a mapping;"));
}
Expand Down Expand Up @@ -273,7 +277,7 @@ public void testValidateKnnField_whenTrainModelUseDifferentVectorDataTypeFromTra
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
ModelDao modelDao = mock(ModelDao.class);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, VectorDataType.BINARY);
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, VectorDataType.BINARY, null);
System.out.println(Objects.requireNonNull(e).getMessage());

assert Objects.requireNonNull(e)
Expand All @@ -298,8 +302,7 @@ public void testValidateKnnField_whenPassByteVectorDataType_thenThrowException()
when(indexMetadata.mapping()).thenReturn(mappingMetadata);
ModelDao modelDao = mock(ModelDao.class);

ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, VectorDataType.BYTE);
System.out.println(Objects.requireNonNull(e).getMessage());
ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, VectorDataType.BYTE, null);

assert Objects.requireNonNull(e)
.getMessage()
Expand All @@ -311,4 +314,37 @@ public void testUpdateVectorDataTypeToParameters_whenVectorDataTypeIsBinary() {
IndexUtil.updateVectorDataTypeToParameters(indexParams, VectorDataType.BINARY);
assertEquals(VectorDataType.BINARY.getValue(), indexParams.get(VECTOR_DATA_TYPE_FIELD));
}

/**
 * Checks that {@code IndexUtil.validateKnnField} reports a validation error when the training request
 * uses the binary vector data type together with a method context whose encoder parameter is PQ.
 */
public void testValidateKnnField_whenPassBinaryVectorDataTypeAndPQEncoder_thenThrowException() {
    final String field = "top_level_field";
    final int dimension = 8;

    // Mapping source: properties -> top_level_field -> field definition
    Map<String, Object> fieldDefinition = Map.of("type", "knn_vector", "dimension", 8, "data_type", "binary", "encoder", "pq");
    Map<String, Object> fieldsByName = Map.of("top_level_field", fieldDefinition);
    Map<String, Object> mappingSource = Map.of("properties", fieldsByName);

    // Mocked index metadata that serves the mapping above
    MappingMetadata mockedMapping = mock(MappingMetadata.class);
    IndexMetadata mockedIndexMetadata = mock(IndexMetadata.class);
    when(mockedMapping.getSourceAsMap()).thenReturn(mappingSource);
    when(mockedIndexMetadata.mapping()).thenReturn(mockedMapping);
    ModelDao mockedModelDao = mock(ModelDao.class);

    // Training method context: IVF whose parameters carry a PQ encoder component
    MethodComponentContext pqEncoder = new MethodComponentContext(ENCODER_PQ, Collections.emptyMap());
    MethodComponentContext ivfWithPqEncoder = new MethodComponentContext(
        METHOD_IVF,
        ImmutableMap.of(METHOD_ENCODER_PARAMETER, pqEncoder)
    );
    KNNMethodContext trainMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, ivfWithPqEncoder);

    ValidationException validationResult = IndexUtil.validateKnnField(
        mockedIndexMetadata,
        field,
        dimension,
        mockedModelDao,
        VectorDataType.BINARY,
        trainMethodContext
    );

    // A null result means validation unexpectedly passed; requireNonNull turns that into a failure
    String actualMessage = Objects.requireNonNull(validationResult).getMessage();
    assert actualMessage.matches("Validation Failed: 1: vector data type \"binary\" is not supported for pq encoder.;");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ public void testValidation_invalid_invalidMethodContext() {
when(knnMethodContext.validate()).thenReturn(validationException);

when(knnMethodContext.isTrainingRequired()).thenReturn(false);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";
Expand Down Expand Up @@ -454,6 +455,7 @@ public void testValidation_invalid_dimensionDoesNotMatch() {
when(knnMethodContext.validate()).thenReturn(null);

when(knnMethodContext.isTrainingRequired()).thenReturn(true);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";
Expand Down Expand Up @@ -511,6 +513,7 @@ public void testValidation_invalid_preferredNodeDoesNotExist() {
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.validate()).thenReturn(null);
when(knnMethodContext.isTrainingRequired()).thenReturn(true);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";
Expand Down Expand Up @@ -573,6 +576,7 @@ public void testValidation_invalid_descriptionToLong() {
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.validate()).thenReturn(null);
when(knnMethodContext.isTrainingRequired()).thenReturn(true);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";
Expand Down Expand Up @@ -623,6 +627,7 @@ public void testValidation_valid_trainingIndexBuiltFromMethod() {
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.validate()).thenReturn(null);
when(knnMethodContext.isTrainingRequired()).thenReturn(true);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";
Expand Down Expand Up @@ -660,6 +665,7 @@ public void testValidation_valid_trainingIndexBuiltFromModel() {
KNNMethodContext knnMethodContext = mock(KNNMethodContext.class);
when(knnMethodContext.validate()).thenReturn(null);
when(knnMethodContext.isTrainingRequired()).thenReturn(true);
when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY);
int dimension = 10;
String trainingIndex = "test-training-index";
String trainingField = "test-training-field";
Expand Down
Loading