From 011775a208002547c9080fc9423399ea500faa52 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Wed, 8 May 2024 15:53:56 -0700 Subject: [PATCH] Block commas in model description (#1692) * Block commas in model description Signed-off-by: Ryan Bogan * Add changelog entry Signed-off-by: Ryan Bogan * Add check in rest handler Signed-off-by: Ryan Bogan * Extract if statement into ModelUtil method Signed-off-by: Ryan Bogan * Remove ingestion from integ test Signed-off-by: Ryan Bogan --------- Signed-off-by: Ryan Bogan --- CHANGELOG.md | 1 + .../opensearch/knn/indices/ModelMetadata.java | 2 + .../org/opensearch/knn/indices/ModelUtil.java | 6 ++ .../plugin/rest/RestTrainModelHandler.java | 2 + .../knn/indices/ModelMetadataTests.java | 60 +++++++++++++----- .../action/RestTrainModelHandlerIT.java | 61 +++++++++++++++++++ 6 files changed, 118 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 708ce86ff..de4319adf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes +* Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java index fa88c8416..f3a5506cd 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java +++ b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java @@ -67,6 +67,7 @@ public ModelMetadata(StreamInput in) throws IOException { // Description and error may be empty. However, reading the string will work as long as they are not null // which is checked in constructor and setters this.description = in.readString(); + ModelUtil.blockCommasInModelDescription(this.description); this.error = in.readString(); if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), IndexUtil.MODEL_NODE_ASSIGNMENT_KEY)) { @@ -123,6 +124,7 @@ public ModelMetadata( this.state = new AtomicReference<>(Objects.requireNonNull(modelState, "modelState must not be null")); this.timestamp = Objects.requireNonNull(timestamp, "timestamp must not be null"); this.description = Objects.requireNonNull(description, "description must not be null"); + ModelUtil.blockCommasInModelDescription(this.description); this.error = Objects.requireNonNull(error, "error must not be null"); this.trainingNodeAssignment = Objects.requireNonNull(trainingNodeAssignment, "node assignment must not be null"); this.methodComponentContext = Objects.requireNonNull(methodComponentContext, "method context must not be null"); diff --git a/src/main/java/org/opensearch/knn/indices/ModelUtil.java b/src/main/java/org/opensearch/knn/indices/ModelUtil.java index 3daaed138..4c6230a46 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelUtil.java +++ b/src/main/java/org/opensearch/knn/indices/ModelUtil.java @@ -16,6 +16,12 @@ */ public class ModelUtil { + public static void blockCommasInModelDescription(String description) { + if (description.contains(",")) { + throw new IllegalArgumentException("Model description cannot contain any commas: ','"); + } + } + public static boolean isModelPresent(ModelMetadata modelMetadata) { return modelMetadata != null; } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java index a4a0de5de..ebbd7fa9b 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java @@ -16,6 +16,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.TrainingJobRouterAction; import org.opensearch.knn.plugin.transport.TrainingModelRequest; @@ -104,6 +105,7 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr searchSize = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false); } else if (MODEL_DESCRIPTION.equals(fieldName) && ensureNotSet(fieldName, description)) { description = parser.textOrNull(); + ModelUtil.blockCommasInModelDescription(description); } else { throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is not a valid " + "parameter."); } diff --git a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java index da56a8421..74715671f 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java @@ -608,20 +608,7 @@ public void testFromResponseMap() throws IOException { String description = "test-description"; String error = "test-error"; String nodeAssignment = "test-node"; - Map nestedParameters = new HashMap() { - { - put("testNestedKey1", "testNestedString"); - put("testNestedKey2", 1); - } - }; - Map parameters = new HashMap<>() { - { - put("testKey1", "testString"); - put("testKey2", 0); - put("testKey3", new MethodComponentContext("ivf", nestedParameters)); - } - }; - MethodComponentContext methodComponentContext = new MethodComponentContext("hnsw", parameters); + MethodComponentContext methodComponentContext = getMethodComponentContext(); MethodComponentContext emptyMethodComponentContext = MethodComponentContext.EMPTY; ModelMetadata expected = new ModelMetadata( @@ -667,6 +654,51 @@ public void testFromResponseMap() throws IOException { metadataAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, null); metadataAsMap.put(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT, null); assertEquals(expected2, fromMap); + } + + public void testBlockCommasInDescription() { + KNNEngine knnEngine = KNNEngine.DEFAULT; + SpaceType spaceType = SpaceType.L2; + int dimension = 128; + ModelState modelState = ModelState.TRAINING; + String timestamp = ZonedDateTime.now(ZoneOffset.UTC).toString(); + String description = "Test, comma, description"; + String error = "test-error"; + String nodeAssignment = "test-node"; + MethodComponentContext methodComponentContext = getMethodComponentContext(); + + Exception e = expectThrows( + IllegalArgumentException.class, + () -> new ModelMetadata( + knnEngine, + spaceType, + dimension, + modelState, + timestamp, + description, + error, + nodeAssignment, + methodComponentContext + ) + ); + assertEquals("Model description cannot contain any commas: ','", e.getMessage()); + } + private static MethodComponentContext getMethodComponentContext() { + Map nestedParameters = new HashMap() { + { + put("testNestedKey1", "testNestedString"); + put("testNestedKey2", 1); + } + }; + Map parameters = new HashMap<>() { + { + put("testKey1", "testString"); + put("testKey2", 0); + put("testKey3", new MethodComponentContext("ivf", nestedParameters)); + } + }; + MethodComponentContext methodComponentContext = new MethodComponentContext("hnsw", parameters); + return methodComponentContext; } } diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java index 1180dcf0a..480a5a40c 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java @@ -13,6 +13,7 @@ import org.apache.hc.core5.http.io.entity.EntityUtils; import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.MediaTypeRegistry; @@ -192,6 +193,66 @@ public void testTrainModel_fail_tooMuchData() throws Exception { assertTrainingFails(modelId, 30, 1000); } + public void testTrainModel_fail_commaInDescription() throws Exception { + // Test checks that training when passing in an id succeeds + + String modelId = "test-model-id"; + String trainingIndexName = "train-index"; + String trainingFieldName = "train-field"; + int dimension = 8; + + // Create a training index and randomly ingest data into it + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + + // Call the train API with this definition: + /* + { + "training_index": "train_index", + "training_field": "train_field", + "dimension": 8, + "description": "this should be allowed to be null", + "method": { + "name":"ivf", + "engine":"faiss", + "space_type": "l2", + "parameters":{ + "nlist":1, + "encoder":{ + "name":"pq", + "parameters":{ + "code_size":2, + "m": 2 + } + } + } + } + } + */ + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, "ivf") + .field(KNN_ENGINE, "faiss") + .field(METHOD_PARAMETER_SPACE_TYPE, "l2") + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 1) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, "pq") + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2) + .field(ENCODER_PARAMETER_PQ_M, 2) + .endObject() + .endObject() + .endObject() + .endObject(); + Map method = xContentBuilderToMap(builder); + + Exception e = expectThrows( + ResponseException.class, + () -> trainModel(modelId, trainingIndexName, trainingFieldName, dimension, method, "dummy description, with comma") + ); + assertTrue(e.getMessage().contains("Model description cannot contain any commas: ','")); + } + public void testTrainModel_success_withId() throws Exception { // Test checks that training when passing in an id succeeds