Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance: support skip_validating_missing_parameters in connector #2812

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
public class ConnectorUtils {

private static final Aws4Signer signer;
public static final String SKIP_VALIDATE_MISSING_PARAMETERS = "skip_validating_missing_parameters";

static {
signer = Aws4Signer.create();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.escapeRemoteInferenceInputData;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processInput;

Expand Down Expand Up @@ -189,7 +190,9 @@ default void preparePayloadAndInvoke(
// override again to always prioritize the input parameter
parameters.putAll(inputParameters);
String payload = connector.createPayload(action, parameters);
connector.validatePayload(payload);
if (!Boolean.parseBoolean(parameters.getOrDefault(SKIP_VALIDATE_MISSING_PARAMETERS, "false"))) {
connector.validatePayload(payload);
}
String userStr = getClient()
.threadPool()
.getThreadContext()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.algorithms.remote;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.argThat;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD;
import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS;

import java.util.Arrays;
import java.util.Map;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.opensearch.client.Client;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ingest.TestTemplateService;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.ConnectorClientConfig;
import org.opensearch.ml.common.connector.RetryBackoffPolicy;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.script.ScriptService;
import org.opensearch.threadpool.ThreadPool;

import com.google.common.collect.ImmutableMap;

public class RemoteConnectorExecutorTest {

Encryptor encryptor;

@Mock
Client client;

@Mock
ThreadPool threadPool;

@Mock
private ScriptService scriptService;

@Mock
ActionListener<Tuple<Integer, ModelTensors>> actionListener;

@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
when(scriptService.compile(any(), any()))
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"result\": \"hello world\"}"));
}

private Connector getConnector(Map<String, String> parameters) {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(PREDICT)
.method("POST")
.url("http:///mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Map<String, String> credential = ImmutableMap
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
return AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol("http")
.parameters(parameters)
.credential(credential)
.actions(Arrays.asList(predictAction))
.connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT))
.build();
}

private AwsConnectorExecutor getExecutor(Connector connector) {
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
ThreadContext threadContext = new ThreadContext(settings);
when(executor.getClient()).thenReturn(client);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
return executor;
}

@Test
public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDisabled() {
Map<String, String> parameters = ImmutableMap
.of(SKIP_VALIDATE_MISSING_PARAMETERS, "false", SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
Connector connector = getConnector(parameters);
AwsConnectorExecutor executor = getExecutor(connector);

RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
.builder()
.parameters(Map.of("input", "You are a ${parameters.role}"))
.actionType(PREDICT)
.build();
String actionType = inputDataSet.getActionType().toString();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build();

Exception exception = Assert
.assertThrows(
IllegalArgumentException.class,
() -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener)
);
assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role");
}

@Test
public void executePreparePayloadAndInvoke_SkipValidateMissingParameterEnabled() {
Map<String, String> parameters = ImmutableMap
.of(SKIP_VALIDATE_MISSING_PARAMETERS, "true", SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
Connector connector = getConnector(parameters);
AwsConnectorExecutor executor = getExecutor(connector);

RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
.builder()
.parameters(Map.of("input", "You are a ${parameters.role}"))
.actionType(PREDICT)
.build();
String actionType = inputDataSet.getActionType().toString();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build();

executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener);
Mockito
.verify(executor, times(1))
.invokeRemoteService(any(), any(), any(), argThat(argument -> argument.contains("You are a ${parameters.role}")), any(), any());
}

@Test
public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDefault() {
Map<String, String> parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2");
Connector connector = getConnector(parameters);
AwsConnectorExecutor executor = getExecutor(connector);

RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
.builder()
.parameters(Map.of("input", "You are a ${parameters.role}"))
.actionType(PREDICT)
.build();
String actionType = inputDataSet.getActionType().toString();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build();

Exception exception = Assert
.assertThrows(
IllegalArgumentException.class,
() -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener)
);
assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,69 @@ public void testPredictRemoteModelWithWrongOutputInterface() throws IOException,
});
}

public void testPredictRemoteModelWithSkipValidatingMissingParameter(
String testCase,
Consumer<Map> verifyResponse,
Consumer<Exception> verifyException
) throws IOException,
InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
return;
}
Response response = createConnector(this.getConnectorBodyBySkipValidatingMissingParameter(testCase));
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = registerRemoteModelWithInterface("openAI-GPT-3.5 completions", connectorId, "correctInterface");
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
response = getTask(taskId);
responseMap = parseResponseToMap(response);
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a ${parameters.test}\"\n" + " }\n" + "}";
try {
response = predictRemoteModel(modelId, predictInput);
responseMap = parseResponseToMap(response);
verifyResponse.accept(responseMap);
} catch (Exception e) {
verifyException.accept(e);
}
}

public void testPredictRemoteModelWithSkipValidatingMissingParameterMissing() throws IOException, InterruptedException {
testPredictRemoteModelWithSkipValidatingMissingParameter("missing", null, (exception) -> {
assertTrue(exception.getMessage().contains("Some parameter placeholder not filled in payload: test"));
});
}

public void testPredictRemoteModelWithSkipValidatingMissingParameterEnabled() throws IOException, InterruptedException {
testPredictRemoteModelWithSkipValidatingMissingParameter("enabled", (responseMap) -> {
List responseList = (List) responseMap.get("inference_results");
responseMap = (Map) responseList.get(0);
responseList = (List) responseMap.get("output");
responseMap = (Map) responseList.get(0);
responseMap = (Map) responseMap.get("dataAsMap");
responseList = (List) responseMap.get("choices");
if (responseList == null) {
assertTrue(checkThrottlingOpenAI(responseMap));
return;
}
responseMap = (Map) responseList.get(0);
assertFalse(((String) responseMap.get("text")).isEmpty());
}, null);
}

public void testPredictRemoteModelWithSkipValidatingMissingParameterDisabled() throws IOException, InterruptedException {
testPredictRemoteModelWithSkipValidatingMissingParameter("disabled", null, (exception) -> {
assertTrue(exception.getMessage().contains("Some parameter placeholder not filled in payload: test"));
});
}

public void testOpenAIChatCompletionModel() throws IOException, InterruptedException {
// Skip test if key is null
if (OPENAI_KEY == null) {
Expand Down Expand Up @@ -870,6 +933,85 @@ public static Response registerRemoteModelWithTTLAndSkipHeapMemCheck(String name
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
}

private String getConnectorBodyBySkipValidatingMissingParameter(String testCase) {
return switch (testCase) {
case "missing" -> completionModelConnectorEntity;
case "enabled" -> "{\n"
+ "\"name\": \"OpenAI Connector\",\n"
+ "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n"
+ "\"version\": 1,\n"
+ "\"client_config\": {\n"
+ " \"max_connection\": 20,\n"
+ " \"connection_timeout\": 50000,\n"
+ " \"read_timeout\": 50000\n"
+ " },\n"
+ "\"protocol\": \"http\",\n"
+ "\"parameters\": {\n"
+ " \"endpoint\": \"api.openai.com\",\n"
+ " \"auth\": \"API_Key\",\n"
+ " \"content_type\": \"application/json\",\n"
+ " \"max_tokens\": 7,\n"
+ " \"temperature\": 0,\n"
+ " \"model\": \"gpt-3.5-turbo-instruct\",\n"
+ " \"skip_validating_missing_parameters\": \"true\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"openAI_key\": \""
+ this.OPENAI_KEY
+ "\"\n"
+ " },\n"
+ " \"actions\": [\n"
+ " {"
+ " \"action_type\": \"predict\",\n"
+ " \"method\": \"POST\",\n"
+ " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n"
+ " \"headers\": {\n"
+ " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+ " },\n"
+ " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n"
+ " }\n"
+ " ]\n"
+ "}";
case "disabled" -> "{\n"
+ "\"name\": \"OpenAI Connector\",\n"
+ "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n"
+ "\"version\": 1,\n"
+ "\"client_config\": {\n"
+ " \"max_connection\": 20,\n"
+ " \"connection_timeout\": 50000,\n"
+ " \"read_timeout\": 50000\n"
+ " },\n"
+ "\"protocol\": \"http\",\n"
+ "\"parameters\": {\n"
+ " \"endpoint\": \"api.openai.com\",\n"
+ " \"auth\": \"API_Key\",\n"
+ " \"content_type\": \"application/json\",\n"
+ " \"max_tokens\": 7,\n"
+ " \"temperature\": 0,\n"
+ " \"model\": \"gpt-3.5-turbo-instruct\",\n"
+ " \"skip_validating_missing_parameters\": \"false\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"openAI_key\": \""
+ this.OPENAI_KEY
+ "\"\n"
+ " },\n"
+ " \"actions\": [\n"
+ " {"
+ " \"action_type\": \"predict\",\n"
+ " \"method\": \"POST\",\n"
+ " \"url\": \"https://${parameters.endpoint}/v1/completions\",\n"
+ " \"headers\": {\n"
+ " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+ " },\n"
+ " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"prompt\\\": \\\"${parameters.prompt}\\\", \\\"max_tokens\\\": ${parameters.max_tokens}, \\\"temperature\\\": ${parameters.temperature} }\"\n"
+ " }\n"
+ " ]\n"
+ "}";
default -> throw new IllegalArgumentException("Invalid test case");
};
}

public static Response registerRemoteModelWithInterface(String name, String connectorId, String testCase) throws IOException {
String registerModelGroupEntity = "{\n"
+ " \"name\": \"remote_model_group\",\n"
Expand Down
Loading