Skip to content

Commit

Permalink
fix error of ML inference processor in foreach processor (opensearch-…
Browse files Browse the repository at this point in the history
…project#2474)

* fix error of ML inference processor in foreach processor

Signed-off-by: Yaliang Wu <[email protected]>

* add IT

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored May 28, 2024
1 parent 0722df1 commit 2c11e7f
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,15 @@ private void appendFieldValue(

modelOutputValue = getModelOutputValue(modelTensorOutput, modelOutputFieldName, ignoreMissing);

List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocument.getSourceAndMetadata(), newDocumentFieldName);
Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<>();
ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata());
ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata());
List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName);

if (dotPathsInArray.size() == 1) {
ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService);
TemplateScript.Factory ingestField = ConfigurationUtils
.compileTemplate(TYPE, tag, newDocumentFieldName, newDocumentFieldName, scriptService);
.compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService);
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
} else {
if (!(modelOutputValue instanceof List)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
import static org.opensearch.ml.processor.MLInferenceIngestProcessor.DEFAULT_OUTPUT_FIELD_NAME;

import java.nio.ByteBuffer;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;

import org.junit.Assert;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
Expand Down Expand Up @@ -1043,6 +1045,43 @@ public void testParseGetDataInTensor_BooleanDataType() {
assertEquals(List.of(true, false, true), result);
}

public void testWriteNewDotPathForNestedObject() {
Map<String, Object> docSourceAndMetaData = new HashMap<>();
docSourceAndMetaData.put("_id", randomAlphaOfLength(5));
docSourceAndMetaData.put("_index", "my_books");

List<Map<String, String>> books = new ArrayList<>();
Map<String, String> book1 = new HashMap<>();
book1.put("title", "first book");
book1.put("description", "this is first book");
Map<String, String> book2 = new HashMap<>();
book2.put("title", "second book");
book2.put("description", "this is second book");
books.add(book1);
books.add(book2);
docSourceAndMetaData.put("books", books);

Map<String, Object> ingestMetadata = new HashMap<>();
ingestMetadata.put("pipeline", "test_pipeline");
ingestMetadata.put("timeestamp", ZonedDateTime.now());
Map<String, String> ingestValue = new HashMap<>();
ingestValue.put("title", "first book");
ingestValue.put("description", "this is first book");
ingestMetadata.put("_value", ingestValue);
docSourceAndMetaData.put("_ingest", ingestMetadata);

String path = "_ingest._value.title";
List<String> newPath = modelExecutor.writeNewDotPathForNestedObject(docSourceAndMetaData, path);
Assert.assertEquals(1, newPath.size());
Assert.assertEquals(path, newPath.get(0));

String path2 = "books.*.title";
List<String> newPath2 = modelExecutor.writeNewDotPathForNestedObject(docSourceAndMetaData, path2);
Assert.assertEquals(2, newPath2.size());
Assert.assertEquals("books.0.title", newPath2.get(0));
Assert.assertEquals("books.1.title", newPath2.get(1));
}

private static Map<String, Object> getNestedObjectWithAnotherNestedObjectSource() {
ArrayList<Object> childDocuments = new ArrayList<>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -960,4 +960,25 @@ public void waitForTask(String taskId, MLTaskState targetState) throws Interrupt
}, CUSTOM_MODEL_TIMEOUT, TimeUnit.SECONDS);
assertTrue(taskDone.get());
}

public String registerRemoteModel(String createConnectorInput, String modelName, boolean deploy) throws IOException,
InterruptedException {
Response response = RestMLRemoteInferenceIT.createConnector(createConnectorInput);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = RestMLRemoteInferenceIT.registerRemoteModel(modelName, modelName, connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = RestMLRemoteInferenceIT.getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
if (deploy) {
response = RestMLRemoteInferenceIT.deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
}
return modelId;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
import org.junit.Before;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.utils.TestHelper;

import com.google.common.collect.ImmutableList;
import com.jayway.jsonpath.JsonPath;

public class RestMLInferenceIngestProcessorIT extends MLCommonsRestTestCase {
private final String OPENAI_KEY = System.getenv("OPENAI_KEY");
private String modelId;
private String openAIChatModelId;
private String bedrockEmbeddingModelId;
private final String completionModelConnectorEntity = "{\n"
+ " \"name\": \"OpenAI text embedding model Connector\",\n"
+ " \"description\": \"The connector to public OpenAI text embedding model service\",\n"
Expand All @@ -52,26 +52,58 @@ public class RestMLInferenceIngestProcessorIT extends MLCommonsRestTestCase {
+ " ]\n"
+ "}";

private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID");
private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY");
private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN");
private static final String GITHUB_CI_AWS_REGION = "us-west-2";

private final String bedrockEmbeddingModelConnectorEntity = "{\n"
+ " \"name\": \"Amazon Bedrock Connector: embedding\",\n"
+ " \"description\": \"The connector to bedrock Titan embedding model\",\n"
+ " \"version\": 1,\n"
+ " \"protocol\": \"aws_sigv4\",\n"
+ " \"parameters\": {\n"
+ " \"region\": \""
+ GITHUB_CI_AWS_REGION
+ "\",\n"
+ " \"service_name\": \"bedrock\",\n"
+ " \"model_name\": \"amazon.titan-embed-text-v1\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"access_key\": \""
+ AWS_ACCESS_KEY_ID
+ "\",\n"
+ " \"secret_key\": \""
+ AWS_SECRET_ACCESS_KEY
+ "\",\n"
+ " \"session_token\": \""
+ AWS_SESSION_TOKEN
+ "\"\n"
+ " },\n"
+ " \"actions\": [\n"
+ " {\n"
+ " \"action_type\": \"predict\",\n"
+ " \"method\": \"POST\",\n"
+ " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke\",\n"
+ " \"headers\": {\n"
+ " \"content-type\": \"application/json\",\n"
+ " \"x-amz-content-sha256\": \"required\"\n"
+ " },\n"
+ " \"request_body\": \"{ \\\"inputText\\\": \\\"${parameters.input}\\\" }\",\n"
+ " \"pre_process_function\": \"connector.pre_process.bedrock.embedding\",\n"
+ " \"post_process_function\": \"connector.post_process.bedrock.embedding\"\n"
+ " }\n"
+ " ]\n"
+ "}";

@Before
public void setup() throws IOException, InterruptedException {
RestMLRemoteInferenceIT.disableClusterConnectorAccessControl();
Thread.sleep(20000);

// create connectors for OPEN AI and register model
Response response = RestMLRemoteInferenceIT.createConnector(completionModelConnectorEntity);
Map responseMap = parseResponseToMap(response);
String openAIConnectorId = (String) responseMap.get("connector_id");
response = RestMLRemoteInferenceIT.registerRemoteModel("openAI-GPT-3.5 chat model", openAIConnectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = RestMLRemoteInferenceIT.getTask(taskId);
responseMap = parseResponseToMap(response);
this.modelId = (String) responseMap.get("model_id");
response = RestMLRemoteInferenceIT.deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
String openAIChatModelName = "openAI-GPT-3.5 chat model " + randomAlphaOfLength(5);
this.openAIChatModelId = registerRemoteModel(completionModelConnectorEntity, openAIChatModelName, true);
String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5);
this.bedrockEmbeddingModelId = registerRemoteModel(bedrockEmbeddingModelConnectorEntity, bedrockEmbeddingModelName, true);
}

public void testMLInferenceProcessorWithObjectFieldType() throws Exception {
Expand All @@ -82,7 +114,7 @@ public void testMLInferenceProcessorWithObjectFieldType() throws Exception {
+ " {\n"
+ " \"ml_inference\": {\n"
+ " \"model_id\": \""
+ this.modelId
+ this.openAIChatModelId
+ "\",\n"
+ " \"input_map\": [\n"
+ " {\n"
Expand Down Expand Up @@ -141,7 +173,7 @@ public void testMLInferenceProcessorWithNestedFieldType() throws Exception {
+ " {\n"
+ " \"ml_inference\": {\n"
+ " \"model_id\": \""
+ this.modelId
+ this.openAIChatModelId
+ "\",\n"
+ " \"input_map\": [\n"
+ " {\n"
Expand Down Expand Up @@ -228,6 +260,96 @@ public void testMLInferenceProcessorWithNestedFieldType() throws Exception {
Assert.assertEquals(0.014352738, (Double) embedding4.get(0), 0.005);
}

public void testMLInferenceProcessorWithForEachProcessor() throws Exception {
String indexName = "my_books";
String pipelineName = "my_books_bedrock_embedding_pipeline";
String createIndexRequestBody = "{\n"
+ " \"settings\": {\n"
+ " \"index\": {\n"
+ " \"default_pipeline\": \""
+ pipelineName
+ "\"\n"
+ " }\n"
+ " },\n"
+ " \"mappings\": {\n"
+ " \"properties\": {\n"
+ " \"books\": {\n"
+ " \"type\": \"nested\",\n"
+ " \"properties\": {\n"
+ " \"title_embedding\": {\n"
+ " \"type\": \"float\"\n"
+ " },\n"
+ " \"title\": {\n"
+ " \"type\": \"text\"\n"
+ " },\n"
+ " \"description\": {\n"
+ " \"type\": \"text\"\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}";
createIndex(indexName, createIndexRequestBody);

String createPipelineRequestBody = "{\n"
+ " \"description\": \"Test bedrock embeddings\",\n"
+ " \"processors\": [\n"
+ " {\n"
+ " \"foreach\": {\n"
+ " \"field\": \"books\",\n"
+ " \"processor\": {\n"
+ " \"ml_inference\": {\n"
+ " \"model_id\": \""
+ this.bedrockEmbeddingModelId
+ "\",\n"
+ " \"input_map\": [\n"
+ " {\n"
+ " \"input\": \"_ingest._value.title\"\n"
+ " }\n"
+ " ],\n"
+ " \"output_map\": [\n"
+ " {\n"
+ " \"_ingest._value.title_embedding\": \"$.embedding\"\n"
+ " }\n"
+ " ],\n"
+ " \"ignore_missing\": false,\n"
+ " \"ignore_failure\": false\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " ]\n"
+ "}";
createPipelineProcessor(createPipelineRequestBody, pipelineName);

// Skip test if key is null
if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {
return;
}
String uploadDocumentRequestBody = "{\n"
+ " \"books\": [{\n"
+ " \"title\": \"first book\",\n"
+ " \"description\": \"This is first book\"\n"
+ " },\n"
+ " {\n"
+ " \"title\": \"second book\",\n"
+ " \"description\": \"This is second book\"\n"
+ " }\n"
+ " ]\n"
+ "}";
uploadDocument(indexName, "1", uploadDocumentRequestBody);
Map document = getDocument(indexName, "1");

List embeddingList = JsonPath.parse(document).read("_source.books[*].title_embedding");
Assert.assertEquals(2, embeddingList.size());

List embedding1 = JsonPath.parse(document).read("_source.books[0].title_embedding");
Assert.assertEquals(1536, embedding1.size());
List embedding2 = JsonPath.parse(document).read("_source.books[1].title_embedding");
Assert.assertEquals(1536, embedding2.size());
}

protected void createPipelineProcessor(String requestBody, final String pipelineName) throws Exception {
Response pipelineCreateResponse = TestHelper
.makeRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -779,8 +779,14 @@ public static Response createConnector(String input) throws IOException {
}

public static Response registerRemoteModel(String name, String connectorId) throws IOException {
return registerRemoteModel("remote_model_group", name, connectorId);
}

public static Response registerRemoteModel(String modelGroupName, String name, String connectorId) throws IOException {
String registerModelGroupEntity = "{\n"
+ " \"name\": \"remote_model_group\",\n"
+ " \"name\": \""
+ modelGroupName
+ "\",\n"
+ " \"description\": \"This is an example description\"\n"
+ "}";
Response response = TestHelper
Expand Down

0 comments on commit 2c11e7f

Please sign in to comment.