Skip to content

Commit

Permalink
Automated model interface generation on aws llms (opensearch-project#…
Browse files Browse the repository at this point in the history
…2689) (opensearch-project#2707)

* Automated model interface generation on aws llms

Signed-off-by: b4sjoo <[email protected]>

* Add UTs

Signed-off-by: b4sjoo <[email protected]>

* Add Comments and TODOs

Signed-off-by: b4sjoo <[email protected]>

---------

Signed-off-by: b4sjoo <[email protected]>
(cherry picked from commit 9b413a7)

Co-authored-by: Sicheng Song <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and b4sjoo authored Jul 23, 2024
1 parent c310660 commit c84c947
Show file tree
Hide file tree
Showing 7 changed files with 880 additions and 7 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.utils;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Spy;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;

import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE;
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE;
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE;
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE;
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE;
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE;
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE;
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE;
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE;
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.updateRegisterModelInputModelInterfaceFieldsByConnector;

public class ModelInterfaceUtilsTest {
@Spy
MLRegisterModelInput registerModelInputWithInnerConnector;

@Spy
MLRegisterModelInput registerModelInputWithStandaloneConnector;

@Spy
public HttpConnector connector;

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

@Before
public void setUp() throws Exception {
registerModelInputWithInnerConnector = MLRegisterModelInput
.builder()
.modelName("test-model-with-inner-connector")
.functionName(FunctionName.REMOTE)
.build();

registerModelInputWithStandaloneConnector = MLRegisterModelInput
.builder()
.modelName("test-model-with-stand-alone-connector")
.functionName(FunctionName.REMOTE)
.build();
}

@Test
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE() {
Map<String, String> parameters = new HashMap<>();
parameters.put("service_name", "bedrock");
parameters.put("model", "ai21.j2-mid-v1");
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();

updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE);
}

@Test
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE() {
Map<String, String> parameters = new HashMap<>();
parameters.put("service_name", "bedrock");
parameters.put("model", "anthropic.claude-3-sonnet-20240229-v1:0");
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();

updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_ANTHROPIC_CLAUDE_V3_SONNET_MODEL_INTERFACE);
}

@Test
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE() {
Map<String, String> parameters = new HashMap<>();
parameters.put("service_name", "bedrock");
parameters.put("model", "anthropic.claude-v2");
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();

updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_ANTHROPIC_CLAUDE_V2_MODEL_INTERFACE);
}

@Test
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE() {
Map<String, String> parameters = new HashMap<>();
parameters.put("service_name", "bedrock");
parameters.put("model", "cohere.embed.english-v3");
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();

updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_COHERE_EMBED_ENGLISH_V3_MODEL_INTERFACE);
}

@Test
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE() {
Map<String, String> parameters = new HashMap<>();
parameters.put("service_name", "bedrock");
parameters.put("model", "cohere.embed.multilingual-v3");
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();

updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_COHERE_EMBED_MULTILINGUAL_V3_MODEL_INTERFACE);
}

@Test
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE() {
Map<String, String> parameters = new HashMap<>();
parameters.put("service_name", "bedrock");
parameters.put("model", "amazon.titan-embed-text-v1");
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();

updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_TITAN_EMBED_TEXT_V1_MODEL_INTERFACE);
}

@Test
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE() {
Map<String, String> parameters = new HashMap<>();
parameters.put("service_name", "bedrock");
parameters.put("model", "amazon.titan-embed-image-v1");
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();

updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), BEDROCK_TITAN_EMBED_MULTI_MODAL_V1_MODEL_INTERFACE);
}

@Test
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorAMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE() {
Map<String, String> parameters = new HashMap<>();
parameters.put("service_name", "comprehend");
parameters.put("api_name", "DetectDominantLanguage");
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();

updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), AMAZON_COMPREHEND_DETECTDOMAINANTLANGUAGE_API_INTERFACE);
}

@Test
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorAMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE() {
Map<String, String> parameters = new HashMap<>();
parameters.put("service_name", "textract");
parameters.put("api_name", "DetectDocumentText");
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();

updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
assertEquals(registerModelInputWithStandaloneConnector.getModelInterface(), AMAZON_TEXTRACT_DETECTDOCUMENTTEXT_API_INTERFACE);
}

@Test
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorServiceNameNotFound() {
Map<String, String> parameters = new HashMap<>();
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();

updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
assertNull(registerModelInputWithStandaloneConnector.getModelInterface());
}

@Test
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorBedrockModelNameNotFound() {
Map<String, String> parameters = new HashMap<>();
parameters.put("service_name", "bedrock");
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();

updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
assertNull(registerModelInputWithStandaloneConnector.getModelInterface());
}

@Test
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorAmazonComprehendAPINameNotFound() {
Map<String, String> parameters = new HashMap<>();
parameters.put("service_name", "comprehend");
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();

updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
assertNull(registerModelInputWithStandaloneConnector.getModelInterface());
}

@Test
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorNullParameters() {
connector = HttpConnector.builder().protocol("http").build();

updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithStandaloneConnector, connector);
assertNull(registerModelInputWithStandaloneConnector.getModelInterface());
}

@Test
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorInnerConnectorBEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE() {
Map<String, String> parameters = new HashMap<>();
parameters.put("service_name", "bedrock");
parameters.put("model", "ai21.j2-mid-v1");
connector = HttpConnector.builder().protocol("http").parameters(parameters).build();
registerModelInputWithInnerConnector.setConnector(connector);
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithInnerConnector);
assertEquals(registerModelInputWithInnerConnector.getModelInterface(), BEDROCK_AI21_LABS_JURASSIC2_MID_V1_MODEL_INTERFACE);
}

@Test
public void testUpdateRegisterModelInputModelInterfaceFieldsByConnectorInnerConnectorNullParameters() {
connector = HttpConnector.builder().protocol("http").build();
registerModelInputWithInnerConnector.setConnector(connector);
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInputWithInnerConnector);
assertNull(registerModelInputWithInnerConnector.getModelInterface());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.utils.ModelInterfaceUtils.updateRegisterModelInputModelInterfaceFieldsByConnector;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
Expand Down Expand Up @@ -239,7 +240,14 @@ private void doRegister(MLRegisterModelInput registerModelInput, ActionListener<
if (Strings.isNotBlank(registerModelInput.getConnectorId())) {
connectorAccessControlHelper.validateConnectorAccess(client, registerModelInput.getConnectorId(), ActionListener.wrap(r -> {
if (Boolean.TRUE.equals(r)) {
createModelGroup(registerModelInput, listener);
if (registerModelInput.getModelInterface() == null) {
mlModelManager.getConnector(registerModelInput.getConnectorId(), ActionListener.wrap(connector -> {
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInput, connector);
createModelGroup(registerModelInput, listener);
}, listener::onFailure));
} else {
createModelGroup(registerModelInput, listener);
}
} else {
listener
.onFailure(
Expand All @@ -261,6 +269,9 @@ private void doRegister(MLRegisterModelInput registerModelInput, ActionListener<
validateInternalConnector(registerModelInput);
ActionListener<MLCreateConnectorResponse> dryRunResultListener = ActionListener.wrap(res -> {
log.info("Dry run create connector successfully");
if (registerModelInput.getModelInterface() == null) {
updateRegisterModelInputModelInterfaceFieldsByConnector(registerModelInput);
}
createModelGroup(registerModelInput, listener);
}, e -> {
log.error(e.getMessage(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1620,7 +1620,7 @@ public void getController(String modelId, ActionListener<MLController> listener)
* @param connectorId connector id
* @param listener action listener
*/
private void getConnector(String connectorId, ActionListener<Connector> listener) {
public void getConnector(String connectorId, ActionListener<Connector> listener) {
GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId);
client.get(getRequest, ActionListener.wrap(r -> {
if (r != null && r.isExists()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,11 @@ protected Response registerRemoteModel(String modelGroupName, String name, Strin
+ " \"description\": \"test model\",\n"
+ " \"connector_id\": \""
+ connectorId
+ "\"\n"
+ "\",\n"
+ " \"interface\": {\n"
+ " \"input\": {},\n"
+ " \"output\": {}\n"
+ " }\n"
+ "}";
return TestHelper
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
Expand Down Expand Up @@ -423,7 +427,11 @@ protected Response registerRemoteModelWithLocalRegexGuardrails(String name, Stri
+ " ],\n"
+ " \"regex\": [\"regex1\", \"regex2\"]\n"
+ " }\n"
+ " }\n"
+ "},\n"
+ " \"interface\": {\n"
+ " \"input\": {},\n"
+ " \"output\": {}\n"
+ " }\n"
+ "}";
return TestHelper
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
Expand Down Expand Up @@ -461,6 +469,10 @@ protected Response registerRemoteModelWithModelGuardrails(String name, String co
+ " \"connector_id\": \""
+ connectorId
+ "\",\n"
+ " \"interface\": {\n"
+ " \"input\": {},\n"
+ " \"output\": {}\n"
+ " },\n"
+ " \"guardrails\": {\n"
+ " \"type\": \"model\",\n"
+ " \"input_guardrail\": {\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,11 @@ public static Response registerRemoteModel(String modelGroupName, String name, S
+ " \"description\": \"test model\",\n"
+ " \"connector_id\": \""
+ connectorId
+ "\"\n"
+ "\",\n"
+ " \"interface\": {\n"
+ " \"input\": {},\n"
+ " \"output\": {}\n"
+ " }\n"
+ "}";
return TestHelper
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
Expand Down Expand Up @@ -856,7 +860,11 @@ public static Response registerRemoteModelWithTTLAndSkipHeapMemCheck(String name
+ " \"deploy_setting\": "
+ " { \"model_ttl_minutes\": "
+ ttl
+ "}\n"
+ "},\n"
+ " \"interface\": {\n"
+ " \"input\": {},\n"
+ " \"output\": {}\n"
+ " }\n"
+ "}";
return TestHelper
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,11 @@ private void setupLLMModel(String connectorId) throws IOException {
+ " \"description\": \"test model\",\n"
+ " \"connector_id\": \""
+ connectorId
+ "\"\n"
+ "\",\n"
+ " \"interface\": {\n"
+ " \"input\": {},\n"
+ " \"output\": {}\n"
+ " }\n"
+ "}";

registerModel(client(), input, response -> {
Expand Down

0 comments on commit c84c947

Please sign in to comment.