From 6194f1dfed95185068fd708cafbfb96e47696b85 Mon Sep 17 00:00:00 2001 From: yuye-aws Date: Mon, 12 Aug 2024 09:40:52 +0800 Subject: [PATCH] Enhance: support skip_validating_missing_parameters in connector (#2812) * introduce skip parameter validation Signed-off-by: yuye-aws * implement ut Signed-off-by: yuye-aws * implement it Signed-off-by: yuye-aws * spotless apply Signed-off-by: yuye-aws --------- Signed-off-by: yuye-aws (cherry picked from commit 9663053d544b59ef238ae36bf49b774543552df5) --- .../algorithms/remote/ConnectorUtils.java | 2 + .../remote/RemoteConnectorExecutor.java | 5 +- .../remote/RemoteConnectorExecutorTest.java | 172 ++++++++++++++++++ .../ml/rest/RestMLRemoteInferenceIT.java | 142 +++++++++++++++ 4 files changed, 320 insertions(+), 1 deletion(-) create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 0adfb99663..ef4f25c79a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -58,6 +58,8 @@ public class ConnectorUtils { private static final Aws4Signer signer; + public static final String SKIP_VALIDATE_MISSING_PARAMETERS = "skip_validating_missing_parameters"; + static { signer = Aws4Signer.create(); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index a126153ba5..48adf6f0c4 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -5,6 +5,7 @@ package org.opensearch.ml.engine.algorithms.remote; +import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS; import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData; import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput; @@ -189,7 +190,9 @@ default void preparePayloadAndInvoke( // override again to always prioritize the input parameter parameters.putAll(inputParameters); String payload = connector.createPayload(action, parameters); - connector.validatePayload(payload); + if (!Boolean.parseBoolean(parameters.getOrDefault(SKIP_VALIDATE_MISSING_PARAMETERS, "false"))) { + connector.validatePayload(payload); + } String userStr = getClient() .threadPool() .getThreadContext() diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java new file mode 100644 index 0000000000..a5cfacdb21 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java @@ -0,0 +1,172 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.remote; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.argThat; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; +import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; +import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; +import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; +import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; +import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS; + +import java.util.Arrays; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.Client; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ingest.TestTemplateService; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.connector.AwsConnector; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.ConnectorClientConfig; +import org.opensearch.ml.common.connector.RetryBackoffPolicy; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.script.ScriptService; +import org.opensearch.threadpool.ThreadPool; + +import com.google.common.collect.ImmutableMap; + +public class RemoteConnectorExecutorTest { + + Encryptor encryptor; + + @Mock + Client client; + + @Mock + ThreadPool threadPool; + + @Mock + private ScriptService scriptService; + + @Mock + ActionListener> actionListener; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + when(scriptService.compile(any(), any())) + .then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"result\": \"hello world\"}")); + } + + private Connector getConnector(Map parameters) { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(PREDICT) + .method("POST") + .url("http:///mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Map credential = ImmutableMap + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + return AwsConnector + .awsConnectorBuilder() + .name("test connector") + .version("1") + .protocol("http") + .parameters(parameters) + .credential(credential) + .actions(Arrays.asList(predictAction)) + .connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) + .build(); + } + + private AwsConnectorExecutor getExecutor(Connector connector) { + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); + Settings settings = Settings.builder().build(); + ThreadContext threadContext = new ThreadContext(settings); + when(executor.getClient()).thenReturn(client); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + return executor; + } + + @Test + public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDisabled() { + Map parameters = ImmutableMap + .of(SKIP_VALIDATE_MISSING_PARAMETERS, "false", SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2"); + Connector connector = getConnector(parameters); + AwsConnectorExecutor executor = getExecutor(connector); + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(Map.of("input", "You are a ${parameters.role}")) + .actionType(PREDICT) + .build(); + String actionType = inputDataSet.getActionType().toString(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(); + + Exception exception = Assert + .assertThrows( + IllegalArgumentException.class, + () -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener) + ); + assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role"); + } + + @Test + public void executePreparePayloadAndInvoke_SkipValidateMissingParameterEnabled() { + Map parameters = ImmutableMap + .of(SKIP_VALIDATE_MISSING_PARAMETERS, "true", SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2"); + Connector connector = getConnector(parameters); + AwsConnectorExecutor executor = getExecutor(connector); + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(Map.of("input", "You are a ${parameters.role}")) + .actionType(PREDICT) + .build(); + String actionType = inputDataSet.getActionType().toString(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(); + + executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener); + Mockito + .verify(executor, times(1)) + .invokeRemoteService(any(), any(), any(), argThat(argument -> argument.contains("You are a ${parameters.role}")), any(), any()); + } + + @Test + public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDefault() { + Map parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2"); + Connector connector = getConnector(parameters); + AwsConnectorExecutor executor = getExecutor(connector); + + RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet + .builder() + .parameters(Map.of("input", "You are a ${parameters.role}")) + .actionType(PREDICT) + .build(); + String actionType = inputDataSet.getActionType().toString(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(); + + Exception exception = Assert + .assertThrows( + IllegalArgumentException.class, + () -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener) + ); + assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role"); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 43852e5c36..75319b99e9 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -287,6 +287,69 @@ public void testPredictRemoteModelWithWrongOutputInterface() throws IOException, }); } + public void testPredictRemoteModelWithSkipValidatingMissingParameter( + String testCase, + Consumer verifyResponse, + Consumer verifyException + ) throws IOException, + InterruptedException { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } + Response response = createConnector(this.getConnectorBodyBySkipValidatingMissingParameter(testCase)); + Map responseMap = parseResponseToMap(response); + String connectorId = (String) responseMap.get("connector_id"); + response = registerRemoteModelWithInterface("openAI-GPT-3.5 completions", connectorId, "correctInterface"); + responseMap = parseResponseToMap(response); + String taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + response = getTask(taskId); + responseMap = parseResponseToMap(response); + String modelId = (String) responseMap.get("model_id"); + response = deployRemoteModel(modelId); + responseMap = parseResponseToMap(response); + taskId = (String) responseMap.get("task_id"); + waitForTask(taskId, MLTaskState.COMPLETED); + String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a ${parameters.test}\"\n" + " }\n" + "}"; + try { + response = predictRemoteModel(modelId, predictInput); + responseMap = parseResponseToMap(response); + verifyResponse.accept(responseMap); + } catch (Exception e) { + verifyException.accept(e); + } + } + + public void testPredictRemoteModelWithSkipValidatingMissingParameterMissing() throws IOException, InterruptedException { + testPredictRemoteModelWithSkipValidatingMissingParameter("missing", null, (exception) -> { + assertTrue(exception.getMessage().contains("Some parameter placeholder not filled in payload: test")); + }); + } + + public void testPredictRemoteModelWithSkipValidatingMissingParameterEnabled() throws IOException, InterruptedException { + testPredictRemoteModelWithSkipValidatingMissingParameter("enabled", (responseMap) -> { + List responseList = (List) responseMap.get("inference_results"); + responseMap = (Map) responseList.get(0); + responseList = (List) responseMap.get("output"); + responseMap = (Map) responseList.get(0); + responseMap = (Map) responseMap.get("dataAsMap"); + responseList = (List) responseMap.get("choices"); + if (responseList == null) { + assertTrue(checkThrottlingOpenAI(responseMap)); + return; + } + responseMap = (Map) responseList.get(0); + assertFalse(((String) responseMap.get("text")).isEmpty()); + }, null); + } + + public void testPredictRemoteModelWithSkipValidatingMissingParameterDisabled() throws IOException, InterruptedException { + testPredictRemoteModelWithSkipValidatingMissingParameter("disabled", null, (exception) -> { + assertTrue(exception.getMessage().contains("Some parameter placeholder not filled in payload: test")); + }); + } + public void testOpenAIChatCompletionModel() throws IOException, InterruptedException { // Skip test if key is null if (OPENAI_KEY == null) { @@ -870,6 +933,85 @@ public static Response registerRemoteModelWithTTLAndSkipHeapMemCheck(String name .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null); } + private String getConnectorBodyBySkipValidatingMissingParameter(String testCase) { + return switch (testCase) { + case "missing" -> completionModelConnectorEntity; + case "enabled" -> "{\n" + + "\"name\": \"OpenAI Connector\",\n" + + "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" + + "\"version\": 1,\n" + + "\"client_config\": {\n" + + " \"max_connection\": 20,\n" + + " \"connection_timeout\": 50000,\n" + + " \"read_timeout\": 50000\n" + + " },\n" + + "\"protocol\": \"http\",\n" + + "\"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": 7,\n" + + " \"temperature\": 0,\n" + + " \"model\": \"gpt-3.5-turbo-instruct\",\n" + + " \"skip_validating_missing_parameters\": \"true\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + this.OPENAI_KEY + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n" + + " \"headers\": {\n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n" + + " }\n" + + " ]\n" + + "}"; + case "disabled" -> "{\n" + + "\"name\": \"OpenAI Connector\",\n" + + "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" + + "\"version\": 1,\n" + + "\"client_config\": {\n" + + " \"max_connection\": 20,\n" + + " \"connection_timeout\": 50000,\n" + + " \"read_timeout\": 50000\n" + + " },\n" + + "\"protocol\": \"http\",\n" + + "\"parameters\": {\n" + + " \"endpoint\": \"api.openai.com\",\n" + + " \"auth\": \"API_Key\",\n" + + " \"content_type\": \"application/json\",\n" + + " \"max_tokens\": 7,\n" + + " \"temperature\": 0,\n" + + " \"model\": \"gpt-3.5-turbo-instruct\",\n" + + " \"skip_validating_missing_parameters\": \"false\"\n" + + " },\n" + + " \"credential\": {\n" + + " \"openAI_key\": \"" + + this.OPENAI_KEY + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n" + + " \"headers\": {\n" + + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + + " },\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n" + + " }\n" + + " ]\n" + + "}"; + default -> throw new IllegalArgumentException("Invalid test case"); + }; + } + public static Response registerRemoteModelWithInterface(String name, String connectorId, String testCase) throws IOException { String registerModelGroupEntity = "{\n" + " \"name\": \"remote_model_group\",\n"