From 87db66a8233b3834fd80e86d027a1543b79f8713 Mon Sep 17 00:00:00 2001 From: Junqiu Lei Date: Mon, 22 Jul 2024 17:21:24 -0700 Subject: [PATCH] Block PQ support when data type is binary format (#1868) --- .../org/opensearch/knn/index/IndexUtil.java | 39 +++++- .../transport/TrainingModelRequest.java | 3 +- .../opensearch/knn/training/VectorReader.java | 2 +- .../org/opensearch/knn/index/FaissIT.java | 114 ------------------ .../opensearch/knn/index/IndexUtilTests.java | 54 +++++++-- .../transport/TrainingModelRequestTests.java | 6 + 6 files changed, 87 insertions(+), 131 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index 2e3f56d96..129384d18 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -87,7 +87,8 @@ public static ValidationException validateKnnField( String field, int expectedDimension, ModelDao modelDao, - VectorDataType expectedVectorDataType + VectorDataType trainRequestVectorDataType, + KNNMethodContext trainRequestKnnMethodContext ) { // Index metadata should not be null if (indexMetadata == null) { @@ -142,27 +143,53 @@ public static ValidationException validateKnnField( return exception; } - if (expectedVectorDataType != null) { - if (VectorDataType.BYTE == expectedVectorDataType) { + if (trainRequestVectorDataType != null) { + if (VectorDataType.BYTE == trainRequestVectorDataType) { exception.addValidationError( - String.format(Locale.ROOT, "vector data type \"%s\" is not supported for training.", expectedVectorDataType.getValue()) + String.format( + Locale.ROOT, + "vector data type \"%s\" is not supported for training.", + trainRequestVectorDataType.getValue() + ) ); return exception; } VectorDataType trainIndexDataType = getVectorDataTypeFromFieldMapping(fieldMap); - if (trainIndexDataType != expectedVectorDataType) { + if (trainIndexDataType != trainRequestVectorDataType) { exception.addValidationError( String.format( Locale.ROOT, "Field \"%s\" has data type %s, which is different from data type used in the training request: %s", field, trainIndexDataType.getValue(), - expectedVectorDataType.getValue() + trainRequestVectorDataType.getValue() ) ); return exception; } + + // Block binary vector data type for pq encoder + if (trainRequestKnnMethodContext != null) { + MethodComponentContext methodComponentContext = trainRequestKnnMethodContext.getMethodComponentContext(); + Map parameters = methodComponentContext.getParameters(); + + if (parameters != null && parameters.containsKey(KNNConstants.METHOD_ENCODER_PARAMETER)) { + MethodComponentContext encoder = (MethodComponentContext) parameters.get(KNNConstants.METHOD_ENCODER_PARAMETER); + if (encoder != null + && KNNConstants.ENCODER_PQ.equals(encoder.getName()) + && VectorDataType.BINARY == trainRequestVectorDataType) { + exception.addValidationError( + String.format( + Locale.ROOT, + "vector data type \"%s\" is not supported for pq encoder.", + trainRequestVectorDataType.getValue() + ) + ); + return exception; + } + } + } } // Return if dimension does not need to be checked diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index 16a1a103a..f7ad997b2 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -332,7 +332,8 @@ public ActionRequestValidationException validate() { this.trainingField, this.dimension, modelDao, - this.vectorDataType + vectorDataType, + knnMethodContext ); if (fieldValidation != null) { exception = exception == null ? new ActionRequestValidationException() : exception; diff --git a/src/main/java/org/opensearch/knn/training/VectorReader.java b/src/main/java/org/opensearch/knn/training/VectorReader.java index f1fd744fd..e94d037bd 100644 --- a/src/main/java/org/opensearch/knn/training/VectorReader.java +++ b/src/main/java/org/opensearch/knn/training/VectorReader.java @@ -88,7 +88,7 @@ public void read( throw validationException; } - ValidationException fieldValidationException = IndexUtil.validateKnnField(indexMetadata, fieldName, -1, null, null); + ValidationException fieldValidationException = IndexUtil.validateKnnField(indexMetadata, fieldName, -1, null, null, null); if (fieldValidationException != null) { validationException = validationException == null ? new ValidationException() : validationException; validationException.addValidationErrors(validationException.validationErrors()); diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 538fc91fd..b0509b45f 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -1712,120 +1712,6 @@ public void testIVF_whenBinaryFormat_whenIVF_thenSuccess() { validateGraphEviction(); } - @SneakyThrows - public void testIVF_whenBinaryFormat_whenIVFPQ_thenSuccess() { - String modelId = "test-model-ivfpq-binary"; - int dimension = 8; - - String trainingIndexName = "train-index-ivfpq-binary"; - String trainingFieldName = "train-field-ivfpq-binary"; - - String trainIndexMapping = XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(trainingFieldName) - .field("type", "knn_vector") - .field("dimension", dimension) - .field("data_type", VectorDataType.BINARY.getValue()) - .startObject(KNN_METHOD) - .field(NAME, KNNEngine.FAISS.getMethod(METHOD_HNSW).getMethodComponent().getName()) - .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.HAMMING.getValue()) - .field(KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(PARAMETERS) - .field(METHOD_PARAMETER_M, 24) - .field(METHOD_PARAMETER_EF_CONSTRUCTION, 128) - .endObject() - .endObject() - .endObject() - .endObject() - .endObject() - .toString(); - - createKnnIndex(trainingIndexName, trainIndexMapping); - - int trainingDataCount = 50; - bulkIngestRandomBinaryVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); - - XContentBuilder trainModelXContentBuilder = XContentFactory.jsonBuilder() - .startObject() - .field(TRAIN_INDEX_PARAMETER, trainingIndexName) - .field(TRAIN_FIELD_PARAMETER, trainingFieldName) - .field(DIMENSION, dimension) - .field(MODEL_DESCRIPTION, "My model description") - .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()) - .startObject(KNN_METHOD) - .field(NAME, METHOD_IVF) - .field(KNN_ENGINE, FAISS_NAME) - .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.HAMMING.getValue()) - .startObject(PARAMETERS) - .field(METHOD_PARAMETER_NPROBES, 1) - .field(METHOD_PARAMETER_NLIST, 1) - .startObject(METHOD_ENCODER_PARAMETER) - .field(NAME, ENCODER_PQ) - .startObject(PARAMETERS) - .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 8) - .field(ENCODER_PARAMETER_PQ_M, 8) - .endObject() - .endObject() - .endObject() - .endObject() - .endObject(); - - trainModel(modelId, trainModelXContentBuilder); - - // Make sure training succeeds after 30 seconds - assertTrainingSucceeds(modelId, 30, 1000); - - // Create knn index from model - String fieldName = "test-field-name-ivfpq-binary"; - String indexName = "test-index-name-ivfpq-binary"; - - String indexMapping = XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(fieldName) - .field("type", "knn_vector") - .field(MODEL_ID, modelId) - .endObject() - .endObject() - .endObject() - .toString(); - - createKnnIndex(indexName, getKNNDefaultIndexSettings(), indexMapping); - Integer[] vector1 = { 11 }; - Integer[] vector2 = { 22 }; - Integer[] vector3 = { 33 }; - Integer[] vector4 = { 44 }; - addKnnDoc(indexName, "1", fieldName, vector1); - addKnnDoc(indexName, "2", fieldName, vector2); - addKnnDoc(indexName, "3", fieldName, vector3); - addKnnDoc(indexName, "4", fieldName, vector4); - - Integer[] queryVector = { 15 }; - int k = 2; - - XContentBuilder queryBuilder = XContentFactory.jsonBuilder() - .startObject() - .startObject("query") - .startObject("knn") - .startObject(fieldName) - .field("vector", queryVector) - .field("k", k) - .endObject() - .endObject() - .endObject() - .endObject(); - Response searchResponse = searchKNNIndex(indexName, queryBuilder, k); - List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); - assertEquals(k, results.size()); - - deleteKNNIndex(indexName); - Thread.sleep(45 * 1000); - deleteModel(modelId); - deleteKNNIndex(trainingIndexName); - validateGraphEviction(); - } - protected void setupKNNIndexForFilterQuery() throws Exception { // Create Mappings XContentBuilder builder = XContentFactory.jsonBuilder() diff --git a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java index 1b00ecfaa..809c7d930 100644 --- a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java @@ -28,6 +28,7 @@ import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.jni.JNIService; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -38,7 +39,10 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.when; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ; import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.IndexUtil.getParametersAtLoading; @@ -117,7 +121,7 @@ public void testValidateKnnField_NestedField() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null); assertNull(e); } @@ -138,7 +142,7 @@ public void testValidateKnnField_NonNestedField() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null); assertNull(e); } @@ -158,7 +162,7 @@ public void testValidateKnnField_NonKnnField() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null); assert Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Field \"" + field + "\" is not of type knn_vector.;"); } @@ -182,7 +186,7 @@ public void testValidateKnnField_WrongFieldPath() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null); assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Field \"" + field + "\" does not exist.;")); } @@ -206,7 +210,7 @@ public void testValidateKnnField_EmptyField() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null); System.out.println(Objects.requireNonNull(e).getMessage()); @@ -223,7 +227,7 @@ public void testValidateKnnField_EmptyIndexMetadata() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null, null); assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Invalid index. Index does not contain a mapping;")); } @@ -273,7 +277,7 @@ public void testValidateKnnField_whenTrainModelUseDifferentVectorDataTypeFromTra when(indexMetadata.mapping()).thenReturn(mappingMetadata); ModelDao modelDao = mock(ModelDao.class); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, VectorDataType.BINARY); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, VectorDataType.BINARY, null); System.out.println(Objects.requireNonNull(e).getMessage()); assert Objects.requireNonNull(e) @@ -298,8 +302,7 @@ public void testValidateKnnField_whenPassByteVectorDataType_thenThrowException() when(indexMetadata.mapping()).thenReturn(mappingMetadata); ModelDao modelDao = mock(ModelDao.class); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, VectorDataType.BYTE); - System.out.println(Objects.requireNonNull(e).getMessage()); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, VectorDataType.BYTE, null); assert Objects.requireNonNull(e) .getMessage() @@ -311,4 +314,37 @@ public void testUpdateVectorDataTypeToParameters_whenVectorDataTypeIsBinary() { IndexUtil.updateVectorDataTypeToParameters(indexParams, VectorDataType.BINARY); assertEquals(VectorDataType.BINARY.getValue(), indexParams.get(VECTOR_DATA_TYPE_FIELD)); } + + public void testValidateKnnField_whenPassBinaryVectorDataTypeAndPQEncoder_thenThrowException() { + Map fieldValues = Map.of("type", "knn_vector", "dimension", 8, "data_type", "binary", "encoder", "pq"); + Map top_level_field = Map.of("top_level_field", fieldValues); + Map properties = Map.of("properties", top_level_field); + String field = "top_level_field"; + int dimension = 8; + + MappingMetadata mappingMetadata = mock(MappingMetadata.class); + when(mappingMetadata.getSourceAsMap()).thenReturn(properties); + IndexMetadata indexMetadata = mock(IndexMetadata.class); + when(indexMetadata.mapping()).thenReturn(mappingMetadata); + ModelDao modelDao = mock(ModelDao.class); + MethodComponentContext pq = new MethodComponentContext(ENCODER_PQ, Collections.emptyMap()); + KNNMethodContext knnMethodContext = new KNNMethodContext( + KNNEngine.FAISS, + SpaceType.INNER_PRODUCT, + new MethodComponentContext(METHOD_IVF, ImmutableMap.of(METHOD_ENCODER_PARAMETER, pq)) + ); + + ValidationException e = IndexUtil.validateKnnField( + indexMetadata, + field, + dimension, + modelDao, + VectorDataType.BINARY, + knnMethodContext + ); + + assert Objects.requireNonNull(e) + .getMessage() + .matches("Validation Failed: 1: vector data type \"binary\" is not supported for pq encoder.;"); + } } diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index 9434a6e41..0fb478c83 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -255,6 +255,7 @@ public void testValidation_invalid_invalidMethodContext() { when(knnMethodContext.validate()).thenReturn(validationException); when(knnMethodContext.isTrainingRequired()).thenReturn(false); + when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field"; @@ -454,6 +455,7 @@ public void testValidation_invalid_dimensionDoesNotMatch() { when(knnMethodContext.validate()).thenReturn(null); when(knnMethodContext.isTrainingRequired()).thenReturn(true); + when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field"; @@ -511,6 +513,7 @@ public void testValidation_invalid_preferredNodeDoesNotExist() { KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); when(knnMethodContext.validate()).thenReturn(null); when(knnMethodContext.isTrainingRequired()).thenReturn(true); + when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field"; @@ -573,6 +576,7 @@ public void testValidation_invalid_descriptionToLong() { KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); when(knnMethodContext.validate()).thenReturn(null); when(knnMethodContext.isTrainingRequired()).thenReturn(true); + when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field"; @@ -623,6 +627,7 @@ public void testValidation_valid_trainingIndexBuiltFromMethod() { KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); when(knnMethodContext.validate()).thenReturn(null); when(knnMethodContext.isTrainingRequired()).thenReturn(true); + when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field"; @@ -660,6 +665,7 @@ public void testValidation_valid_trainingIndexBuiltFromModel() { KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); when(knnMethodContext.validate()).thenReturn(null); when(knnMethodContext.isTrainingRequired()).thenReturn(true); + when(knnMethodContext.getMethodComponentContext()).thenReturn(MethodComponentContext.EMPTY); int dimension = 10; String trainingIndex = "test-training-index"; String trainingField = "test-training-field";