Skip to content

Commit

Permalink
Block commas in model description (#1692)
Browse files Browse the repository at this point in the history
* Block commas in model description

Signed-off-by: Ryan Bogan <[email protected]>

* Add changelog entry

Signed-off-by: Ryan Bogan <[email protected]>

* Add check in rest handler

Signed-off-by: Ryan Bogan <[email protected]>

* Extract if statement into ModelUtil method

Signed-off-by: Ryan Bogan <[email protected]>

* Remove ingestion from integ test

Signed-off-by: Ryan Bogan <[email protected]>

---------

Signed-off-by: Ryan Bogan <[email protected]>
  • Loading branch information
ryanbogan authored May 8, 2024
1 parent 73d5425 commit 011775a
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
* Block commas in model description [#1692](https://github.com/opensearch-project/k-NN/pull/1692)
### Infrastructure
### Documentation
### Maintenance
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/indices/ModelMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public ModelMetadata(StreamInput in) throws IOException {
// Description and error may be empty. However, reading the string will work as long as they are not null
// which is checked in constructor and setters
this.description = in.readString();
ModelUtil.blockCommasInModelDescription(this.description);
this.error = in.readString();

if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), IndexUtil.MODEL_NODE_ASSIGNMENT_KEY)) {
Expand Down Expand Up @@ -123,6 +124,7 @@ public ModelMetadata(
this.state = new AtomicReference<>(Objects.requireNonNull(modelState, "modelState must not be null"));
this.timestamp = Objects.requireNonNull(timestamp, "timestamp must not be null");
this.description = Objects.requireNonNull(description, "description must not be null");
ModelUtil.blockCommasInModelDescription(this.description);
this.error = Objects.requireNonNull(error, "error must not be null");
this.trainingNodeAssignment = Objects.requireNonNull(trainingNodeAssignment, "node assignment must not be null");
this.methodComponentContext = Objects.requireNonNull(methodComponentContext, "method context must not be null");
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/org/opensearch/knn/indices/ModelUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
*/
public class ModelUtil {

/**
 * Rejects model descriptions that contain a comma.
 *
 * A {@code null} description is accepted without error. Callers may hand in a value that was read with
 * a nullable accessor (for example a parser's {@code textOrNull()}), and an absent description must not
 * surface as a {@link NullPointerException} from this validation step.
 *
 * @param description model description to validate; may be {@code null}
 * @throws IllegalArgumentException if {@code description} contains at least one comma
 */
public static void blockCommasInModelDescription(String description) {
    // Null means "no description supplied": there is nothing to validate.
    if (description != null && description.contains(",")) {
        throw new IllegalArgumentException("Model description cannot contain any commas: ','");
    }
}

/**
 * Reports whether model metadata is available.
 *
 * @param modelMetadata metadata reference to check; may be {@code null}
 * @return {@code true} when {@code modelMetadata} is not {@code null}, {@code false} otherwise
 */
public static boolean isModelPresent(ModelMetadata modelMetadata) {
    final boolean metadataAvailable = null != modelMetadata;
    return metadataAvailable;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.knn.plugin.transport.TrainingJobRouterAction;
import org.opensearch.knn.plugin.transport.TrainingModelRequest;
Expand Down Expand Up @@ -104,6 +105,7 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr
searchSize = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
} else if (MODEL_DESCRIPTION.equals(fieldName) && ensureNotSet(fieldName, description)) {
description = parser.textOrNull();
ModelUtil.blockCommasInModelDescription(description);
} else {
throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is not a valid " + "parameter.");
}
Expand Down
60 changes: 46 additions & 14 deletions src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -608,20 +608,7 @@ public void testFromResponseMap() throws IOException {
String description = "test-description";
String error = "test-error";
String nodeAssignment = "test-node";
Map<String, Object> nestedParameters = new HashMap<String, Object>() {
{
put("testNestedKey1", "testNestedString");
put("testNestedKey2", 1);
}
};
Map<String, Object> parameters = new HashMap<>() {
{
put("testKey1", "testString");
put("testKey2", 0);
put("testKey3", new MethodComponentContext("ivf", nestedParameters));
}
};
MethodComponentContext methodComponentContext = new MethodComponentContext("hnsw", parameters);
MethodComponentContext methodComponentContext = getMethodComponentContext();
MethodComponentContext emptyMethodComponentContext = MethodComponentContext.EMPTY;

ModelMetadata expected = new ModelMetadata(
Expand Down Expand Up @@ -667,6 +654,51 @@ public void testFromResponseMap() throws IOException {
metadataAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, null);
metadataAsMap.put(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT, null);
assertEquals(expected2, fromMap);
}

/**
 * Verifies that constructing a {@code ModelMetadata} whose description contains commas fails with an
 * {@link IllegalArgumentException} carrying the expected message.
 */
public void testBlockCommasInDescription() {
    // Values that must be computed before the constructor call is attempted.
    final String creationTimestamp = ZonedDateTime.now(ZoneOffset.UTC).toString();
    final MethodComponentContext componentContext = getMethodComponentContext();

    // The only input under test: a description that contains commas.
    final String descriptionWithCommas = "Test, comma, description";

    Exception thrown = expectThrows(
        IllegalArgumentException.class,
        () -> new ModelMetadata(
            KNNEngine.DEFAULT,
            SpaceType.L2,
            128,
            ModelState.TRAINING,
            creationTimestamp,
            descriptionWithCommas,
            "test-error",
            "test-node",
            componentContext
        )
    );

    assertEquals("Model description cannot contain any commas: ','", thrown.getMessage());
}

/**
 * Builds the {@code MethodComponentContext} fixture used by tests in this class: an "hnsw" component
 * whose parameters hold a string, an integer, and a nested "ivf" component with its own parameters.
 *
 * @return a newly built context; every call creates fresh parameter maps
 */
private static MethodComponentContext getMethodComponentContext() {
    // Plain HashMaps are used instead of double-brace initialization, which would define an
    // anonymous HashMap subclass for each map just to populate it.
    Map<String, Object> nestedParameters = new HashMap<>();
    nestedParameters.put("testNestedKey1", "testNestedString");
    nestedParameters.put("testNestedKey2", 1);

    Map<String, Object> parameters = new HashMap<>();
    parameters.put("testKey1", "testString");
    parameters.put("testKey2", 0);
    parameters.put("testKey3", new MethodComponentContext("ivf", nestedParameters));

    return new MethodComponentContext("hnsw", parameters);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.MediaTypeRegistry;
Expand Down Expand Up @@ -192,6 +193,66 @@ public void testTrainModel_fail_tooMuchData() throws Exception {
assertTrainingFails(modelId, 30, 1000);
}

/**
 * Checks that a train request fails when the model description contains a comma, and that the
 * failure message reported to the client mentions the comma restriction.
 */
public void testTrainModel_fail_commaInDescription() throws Exception {
    final String sourceIndex = "train-index";
    final String sourceField = "train-field";
    final int dimension = 8;

    // Method definition sent with the request:
    /*
        {
            "name":"ivf",
            "engine":"faiss",
            "space_type": "l2",
            "parameters":{
                "nlist":1,
                "encoder":{
                    "name":"pq",
                    "parameters":{
                        "code_size":2,
                        "m": 2
                    }
                }
            }
        }
     */
    XContentBuilder methodDefinition = XContentFactory.jsonBuilder()
        .startObject()
        .field(NAME, "ivf")
        .field(KNN_ENGINE, "faiss")
        .field(METHOD_PARAMETER_SPACE_TYPE, "l2")
        .startObject(PARAMETERS)
        .field(METHOD_PARAMETER_NLIST, 1)
        .startObject(METHOD_ENCODER_PARAMETER)
        .field(NAME, "pq")
        .startObject(PARAMETERS)
        .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2)
        .field(ENCODER_PARAMETER_PQ_M, 2)
        .endObject()
        .endObject()
        .endObject()
        .endObject();
    Map<String, Object> methodAsMap = xContentBuilderToMap(methodDefinition);

    // Only the training index is created here; this test does not ingest any documents into it.
    createBasicKnnIndex(sourceIndex, sourceField, dimension);

    // The description below is the only deliberately invalid part of the request.
    Exception failure = expectThrows(
        ResponseException.class,
        () -> trainModel("test-model-id", sourceIndex, sourceField, dimension, methodAsMap, "dummy description, with comma")
    );
    assertTrue(failure.getMessage().contains("Model description cannot contain any commas: ','"));
}

public void testTrainModel_success_withId() throws Exception {
// Test checks that training when passing in an id succeeds

Expand Down

0 comments on commit 011775a

Please sign in to comment.