Skip to content

Commit

Permalink
Enhance: support skip_validating_missing_parameters in connector (ope…
Browse files Browse the repository at this point in the history
…nsearch-project#2812)

* introduce skip parameter validation

Signed-off-by: yuye-aws <[email protected]>

* implement ut

Signed-off-by: yuye-aws <[email protected]>

* implement it

Signed-off-by: yuye-aws <[email protected]>

* spotless apply

Signed-off-by: yuye-aws <[email protected]>

---------

Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
yuye-aws authored Aug 12, 2024
1 parent a4dff63 commit 9663053
Show file tree
Hide file tree
Showing 4 changed files with 320 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
public class ConnectorUtils {

private static final Aws4Signer signer;
public static final String SKIP_VALIDATE_MISSING_PARAMETERS = "skip_validating_missing_parameters";

static {
signer = Aws4Signer.create();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;

Expand Down Expand Up @@ -189,7 +190,9 @@ default void preparePayloadAndInvoke(
// override again to always prioritize the input parameter
parameters.putAll(inputParameters);
String payload = connector.createPayload(action, parameters);
connector.validatePayload(payload);
if (!Boolean.parseBoolean(parameters.getOrDefault(SKIP_VALIDATE_MISSING_PARAMETERS, "false"))) {
connector.validatePayload(payload);
}
String userStr = getClient()
.threadPool()
.getThreadContext()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.algorithms.remote;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.argThat;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS;

import java.util.Arrays;
import java.util.Map;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.client.Client;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ingest.TestTemplateService;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.ConnectorClientConfig;
import org.opensearch.ml.common.connector.RetryBackoffPolicy;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.script.ScriptService;
import org.opensearch.threadpool.ThreadPool;

import com.google.common.collect.ImmutableMap;

public class RemoteConnectorExecutorTest {

Encryptor encryptor;

@Mock
Client client;

@Mock
ThreadPool threadPool;

@Mock
private ScriptService scriptService;

@Mock
ActionListener<Tuple<Integer, ModelTensors>> actionListener;

@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
when(scriptService.compile(any(), any()))
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"result\": \"hello world\"}"));
}

private Connector getConnector(Map<String, String> parameters) {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(PREDICT)
.method("POST")
.url("http:///mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Map<String, String> credential = ImmutableMap
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
return AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol("http")
.parameters(parameters)
.credential(credential)
.actions(Arrays.asList(predictAction))
.connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT))
.build();
}

private AwsConnectorExecutor getExecutor(Connector connector) {
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
ThreadContext threadContext = new ThreadContext(settings);
when(executor.getClient()).thenReturn(client);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
return executor;
}

@Test
public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDisabled() {
Map<String, String> parameters = ImmutableMap
.of(SKIP_VALIDATE_MISSING_PARAMETERS, "false", SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
Connector connector = getConnector(parameters);
AwsConnectorExecutor executor = getExecutor(connector);

RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
.builder()
.parameters(Map.of("input", "You are a ${parameters.role}"))
.actionType(PREDICT)
.build();
String actionType = inputDataSet.getActionType().toString();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build();

Exception exception = Assert
.assertThrows(
IllegalArgumentException.class,
() -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener)
);
assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role");
}

@Test
public void executePreparePayloadAndInvoke_SkipValidateMissingParameterEnabled() {
Map<String, String> parameters = ImmutableMap
.of(SKIP_VALIDATE_MISSING_PARAMETERS, "true", SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
Connector connector = getConnector(parameters);
AwsConnectorExecutor executor = getExecutor(connector);

RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
.builder()
.parameters(Map.of("input", "You are a ${parameters.role}"))
.actionType(PREDICT)
.build();
String actionType = inputDataSet.getActionType().toString();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build();

executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener);
Mockito
.verify(executor, times(1))
.invokeRemoteService(any(), any(), any(), argThat(argument -> argument.contains("You are a ${parameters.role}")), any(), any());
}

@Test
public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDefault() {
Map<String, String> parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
Connector connector = getConnector(parameters);
AwsConnectorExecutor executor = getExecutor(connector);

RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
.builder()
.parameters(Map.of("input", "You are a ${parameters.role}"))
.actionType(PREDICT)
.build();
String actionType = inputDataSet.getActionType().toString();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build();

Exception exception = Assert
.assertThrows(
IllegalArgumentException.class,
() -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener)
);
assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,69 @@ public void testPredictRemoteModelWithWrongOutputInterface() throws IOException,
});
}

public void testPredictRemoteModelWithSkipValidatingMissingParameter(
String testCase,
Consumer<Map> verifyResponse,
Consumer<Exception> verifyException
) throws IOException,
InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
return;
}
Response response = createConnector(this.getConnectorBodyBySkipValidatingMissingParameter(testCase));
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = registerRemoteModelWithInterface("openAI-GPT-3.5 completions", connectorId, "correctInterface");
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a ${parameters.test}\"\n" + " }\n" + "}";
try {
response = predictRemoteModel(modelId, predictInput);
responseMap = parseResponseToMap(response);
verifyResponse.accept(responseMap);
} catch (Exception e) {
verifyException.accept(e);
}
}

public void testPredictRemoteModelWithSkipValidatingMissingParameterMissing() throws IOException, InterruptedException {
testPredictRemoteModelWithSkipValidatingMissingParameter("missing", null, (exception) -> {
assertTrue(exception.getMessage().contains("Some parameter placeholder not filled in payload: test"));
});
}

public void testPredictRemoteModelWithSkipValidatingMissingParameterEnabled() throws IOException, InterruptedException {
testPredictRemoteModelWithSkipValidatingMissingParameter("enabled", (responseMap) -> {
List responseList = (List) responseMap.get("inference_results");
responseMap = (Map) responseList.get(0);
responseList = (List) responseMap.get("output");
responseMap = (Map) responseList.get(0);
responseMap = (Map) responseMap.get("dataAsMap");
responseList = (List) responseMap.get("choices");
if (responseList == null) {
assertTrue(checkThrottlingOpenAI(responseMap));
return;
}
responseMap = (Map) responseList.get(0);
assertFalse(((String) responseMap.get("text")).isEmpty());
}, null);
}

public void testPredictRemoteModelWithSkipValidatingMissingParameterDisabled() throws IOException, InterruptedException {
testPredictRemoteModelWithSkipValidatingMissingParameter("disabled", null, (exception) -> {
assertTrue(exception.getMessage().contains("Some parameter placeholder not filled in payload: test"));
});
}

public void testOpenAIChatCompletionModel() throws IOException, InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
Expand Down Expand Up @@ -870,6 +933,85 @@ public static Response registerRemoteModelWithTTLAndSkipHeapMemCheck(String name
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
}

private String getConnectorBodyBySkipValidatingMissingParameter(String testCase) {
return switch (testCase) {
case "missing" -> completionModelConnectorEntity;
case "enabled" -> "{\n"
+ "\"name\": \"OpenAI Connector\",\n"
+ "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n"
+ "\"version\": 1,\n"
+ "\"client_config\": {\n"
+ " \"max_connection\": 20,\n"
+ " \"connection_timeout\": 50000,\n"
+ " \"read_timeout\": 50000\n"
+ " },\n"
+ "\"protocol\": \"http\",\n"
+ "\"parameters\": {\n"
+ " \"endpoint\": \"api.openai.com\",\n"
+ " \"auth\": \"API_Key\",\n"
+ " \"content_type\": \"application/json\",\n"
+ " \"max_tokens\": 7,\n"
+ " \"temperature\": 0,\n"
+ " \"model\": \"gpt-3.5-turbo-instruct\",\n"
+ " \"skip_validating_missing_parameters\": \"true\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"openAI_key\": \""
+ this.OPENAI_KEY
+ "\"\n"
+ " },\n"
+ " \"actions\": [\n"
+ " {"
+ " \"action_type\": \"predict\",\n"
+ " \"method\": \"POST\",\n"
+ " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n"
+ " \"headers\": {\n"
+ " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+ " },\n"
+ " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n"
+ " }\n"
+ " ]\n"
+ "}";
case "disabled" -> "{\n"
+ "\"name\": \"OpenAI Connector\",\n"
+ "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n"
+ "\"version\": 1,\n"
+ "\"client_config\": {\n"
+ " \"max_connection\": 20,\n"
+ " \"connection_timeout\": 50000,\n"
+ " \"read_timeout\": 50000\n"
+ " },\n"
+ "\"protocol\": \"http\",\n"
+ "\"parameters\": {\n"
+ " \"endpoint\": \"api.openai.com\",\n"
+ " \"auth\": \"API_Key\",\n"
+ " \"content_type\": \"application/json\",\n"
+ " \"max_tokens\": 7,\n"
+ " \"temperature\": 0,\n"
+ " \"model\": \"gpt-3.5-turbo-instruct\",\n"
+ " \"skip_validating_missing_parameters\": \"false\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"openAI_key\": \""
+ this.OPENAI_KEY
+ "\"\n"
+ " },\n"
+ " \"actions\": [\n"
+ " {"
+ " \"action_type\": \"predict\",\n"
+ " \"method\": \"POST\",\n"
+ " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n"
+ " \"headers\": {\n"
+ " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+ " },\n"
+ " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n"
+ " }\n"
+ " ]\n"
+ "}";
default -> throw new IllegalArgumentException("Invalid test case");
};
}

public static Response registerRemoteModelWithInterface(String name, String connectorId, String testCase) throws IOException {
String registerModelGroupEntity = "{\n"
+ " \"name\": \"remote_model_group\",\n"
Expand Down

0 comments on commit 9663053

Please sign in to comment.