forked from opensearch-project/ml-commons
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enhance: support skip_validating_missing_parameters in connector (ope…
…nsearch-project#2812) * introduce skip parameter validation Signed-off-by: yuye-aws <[email protected]> * implement ut Signed-off-by: yuye-aws <[email protected]> * implement it Signed-off-by: yuye-aws <[email protected]> * spotless apply Signed-off-by: yuye-aws <[email protected]> --------- Signed-off-by: yuye-aws <[email protected]>
- Loading branch information
Showing
4 changed files
with
320 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
172 changes: 172 additions & 0 deletions
172
...src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.ml.engine.algorithms.remote; | ||
|
||
import static org.mockito.ArgumentMatchers.any; | ||
import static org.mockito.Mockito.argThat; | ||
import static org.mockito.Mockito.spy; | ||
import static org.mockito.Mockito.times; | ||
import static org.mockito.Mockito.when; | ||
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; | ||
import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; | ||
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; | ||
import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; | ||
import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; | ||
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.SKIP_VALIDATE_MISSING_PARAMETERS; | ||
|
||
import java.util.Arrays; | ||
import java.util.Map; | ||
|
||
import org.junit.Assert; | ||
import org.junit.Before; | ||
import org.junit.Test; | ||
import org.mockito.Mock; | ||
import org.mockito.Mockito; | ||
import org.mockito.MockitoAnnotations; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.common.collect.Tuple; | ||
import org.opensearch.common.settings.Settings; | ||
import org.opensearch.common.util.concurrent.ThreadContext; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.ingest.TestTemplateService; | ||
import org.opensearch.ml.common.FunctionName; | ||
import org.opensearch.ml.common.connector.AwsConnector; | ||
import org.opensearch.ml.common.connector.Connector; | ||
import org.opensearch.ml.common.connector.ConnectorAction; | ||
import org.opensearch.ml.common.connector.ConnectorClientConfig; | ||
import org.opensearch.ml.common.connector.RetryBackoffPolicy; | ||
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
import org.opensearch.ml.common.input.MLInput; | ||
import org.opensearch.ml.common.output.model.ModelTensors; | ||
import org.opensearch.ml.engine.encryptor.Encryptor; | ||
import org.opensearch.ml.engine.encryptor.EncryptorImpl; | ||
import org.opensearch.script.ScriptService; | ||
import org.opensearch.threadpool.ThreadPool; | ||
|
||
import com.google.common.collect.ImmutableMap; | ||
|
||
public class RemoteConnectorExecutorTest { | ||
|
||
Encryptor encryptor; | ||
|
||
@Mock | ||
Client client; | ||
|
||
@Mock | ||
ThreadPool threadPool; | ||
|
||
@Mock | ||
private ScriptService scriptService; | ||
|
||
@Mock | ||
ActionListener<Tuple<Integer, ModelTensors>> actionListener; | ||
|
||
@Before | ||
public void setUp() { | ||
MockitoAnnotations.openMocks(this); | ||
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); | ||
when(scriptService.compile(any(), any())) | ||
.then(invocation -> new TestTemplateService.MockTemplateScript.Factory("{\"result\": \"hello world\"}")); | ||
} | ||
|
||
private Connector getConnector(Map<String, String> parameters) { | ||
ConnectorAction predictAction = ConnectorAction | ||
.builder() | ||
.actionType(PREDICT) | ||
.method("POST") | ||
.url("http:///mock") | ||
.requestBody("{\"input\": \"${parameters.input}\"}") | ||
.build(); | ||
Map<String, String> credential = ImmutableMap | ||
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); | ||
return AwsConnector | ||
.awsConnectorBuilder() | ||
.name("test connector") | ||
.version("1") | ||
.protocol("http") | ||
.parameters(parameters) | ||
.credential(credential) | ||
.actions(Arrays.asList(predictAction)) | ||
.connectorClientConfig(new ConnectorClientConfig(10, 10, 10, 1, 1, 0, RetryBackoffPolicy.CONSTANT)) | ||
.build(); | ||
} | ||
|
||
private AwsConnectorExecutor getExecutor(Connector connector) { | ||
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); | ||
Settings settings = Settings.builder().build(); | ||
ThreadContext threadContext = new ThreadContext(settings); | ||
when(executor.getClient()).thenReturn(client); | ||
when(client.threadPool()).thenReturn(threadPool); | ||
when(threadPool.getThreadContext()).thenReturn(threadContext); | ||
return executor; | ||
} | ||
|
||
@Test | ||
public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDisabled() { | ||
Map<String, String> parameters = ImmutableMap | ||
.of(SKIP_VALIDATE_MISSING_PARAMETERS, "false", SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2"); | ||
Connector connector = getConnector(parameters); | ||
AwsConnectorExecutor executor = getExecutor(connector); | ||
|
||
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet | ||
.builder() | ||
.parameters(Map.of("input", "You are a ${parameters.role}")) | ||
.actionType(PREDICT) | ||
.build(); | ||
String actionType = inputDataSet.getActionType().toString(); | ||
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(); | ||
|
||
Exception exception = Assert | ||
.assertThrows( | ||
IllegalArgumentException.class, | ||
() -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener) | ||
); | ||
assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role"); | ||
} | ||
|
||
@Test | ||
public void executePreparePayloadAndInvoke_SkipValidateMissingParameterEnabled() { | ||
Map<String, String> parameters = ImmutableMap | ||
.of(SKIP_VALIDATE_MISSING_PARAMETERS, "true", SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2"); | ||
Connector connector = getConnector(parameters); | ||
AwsConnectorExecutor executor = getExecutor(connector); | ||
|
||
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet | ||
.builder() | ||
.parameters(Map.of("input", "You are a ${parameters.role}")) | ||
.actionType(PREDICT) | ||
.build(); | ||
String actionType = inputDataSet.getActionType().toString(); | ||
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(); | ||
|
||
executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener); | ||
Mockito | ||
.verify(executor, times(1)) | ||
.invokeRemoteService(any(), any(), any(), argThat(argument -> argument.contains("You are a ${parameters.role}")), any(), any()); | ||
} | ||
|
||
@Test | ||
public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDefault() { | ||
Map<String, String> parameters = ImmutableMap.of(SERVICE_NAME_FIELD, "sagemaker", REGION_FIELD, "us-west-2"); | ||
Connector connector = getConnector(parameters); | ||
AwsConnectorExecutor executor = getExecutor(connector); | ||
|
||
RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet | ||
.builder() | ||
.parameters(Map.of("input", "You are a ${parameters.role}")) | ||
.actionType(PREDICT) | ||
.build(); | ||
String actionType = inputDataSet.getActionType().toString(); | ||
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(); | ||
|
||
Exception exception = Assert | ||
.assertThrows( | ||
IllegalArgumentException.class, | ||
() -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener) | ||
); | ||
assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters