Skip to content

Commit

Permalink
Change validate localhost logic to same with existing code
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Apr 29, 2024
1 parent 13dc670 commit 05d14bc
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 286 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import java.io.IOException;
import java.net.URI;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashMap;
Expand All @@ -40,7 +39,6 @@
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import org.opensearch.script.ScriptService;

import com.jayway.jsonpath.JsonPath;
Expand Down Expand Up @@ -269,13 +267,7 @@ public static SdkHttpFullRequest buildSdkRequest(
Map<String, String> parameters,
String payload,
SdkHttpMethod method
) throws Exception {
String endpoint = connector.getPredictEndpoint(parameters);
URL url = new URL(endpoint);
String protocol = url.getProtocol();
String host = url.getHost();
int port = url.getPort();
MLHttpClientFactory.validate(protocol, host, port);
) {
String charset = parameters.getOrDefault("charset", "UTF-8");
RequestBody requestBody;
if (payload != null) {
Expand All @@ -287,6 +279,7 @@ public static SdkHttpFullRequest buildSdkRequest(
log.error("Content length is 0. Aborting request to remote model");
throw new IllegalArgumentException("Content length is 0. Aborting request to remote model");
}
String endpoint = connector.getPredictEndpoint(parameters);
SdkHttpFullRequest.Builder builder = SdkHttpFullRequest
.builder()
.method(method)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static software.amazon.awssdk.http.SdkHttpMethod.GET;
import static software.amazon.awssdk.http.SdkHttpMethod.POST;

import java.net.URL;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.time.Duration;
Expand All @@ -30,6 +31,8 @@
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import org.opensearch.script.ScriptService;

import com.google.common.annotations.VisibleForTesting;

import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
Expand Down Expand Up @@ -87,9 +90,11 @@ public void invokeRemoteModel(
switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) {
case "POST":
log.debug("original payload to remote model: " + payload);
validateHttpClientParameters(parameters);
request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST);
break;
case "GET":
validateHttpClientParameters(parameters);
request = ConnectorUtils.buildSdkRequest(connector, parameters, null, GET);
break;
default:
Expand Down Expand Up @@ -120,4 +125,14 @@ public void invokeRemoteModel(
actionListener.onFailure(new MLException("Fail to execute http connector", e));
}
}

@VisibleForTesting
protected void validateHttpClientParameters(Map<String, String> parameters) throws Exception {
String endpoint = connector.getPredictEndpoint(parameters);
URL url = new URL(endpoint);
String protocol = url.getProtocol();
String host = url.getHost();
int port = url.getPort();
MLHttpClientFactory.validate(protocol, host, port);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.encryptor.Encryptor;
Expand Down Expand Up @@ -94,43 +93,6 @@ public void executePredict_RemoteInferenceInput_MissingCredential() {
.build();
}

@Test
public void executePredict_RemoteInferenceInput_invalidIp() {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test1.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Map<String, String> credential = ImmutableMap
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol("http")
.parameters(parameters)
.credential(credential)
.actions(Arrays.asList(predictAction))
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(executor.getClient()).thenReturn(client);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);

MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener);
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
assert exceptionCaptor.getValue() instanceof MLException;
assertEquals("Fail to execute predict in aws connector", exceptionCaptor.getValue().getMessage());
}

@Test
public void executePredict_RemoteInferenceInput_EmptyIpAddress() {
ConnectorAction predictAction = ConnectorAction
Expand Down Expand Up @@ -164,45 +126,8 @@ public void executePredict_RemoteInferenceInput_EmptyIpAddress() {
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener);
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
assert exceptionCaptor.getValue() instanceof IllegalArgumentException;
assertEquals("Remote inference host name has private ip address: ", exceptionCaptor.getValue().getMessage());
}

@Test
public void executePredict_RemoteInferenceInput_illegalIpAddress() {
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://localhost/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Map<String, String> credential = ImmutableMap
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol("http")
.parameters(parameters)
.credential(credential)
.actions(Arrays.asList(predictAction))
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector));
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(executor.getClient()).thenReturn(client);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);

MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener);
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());
assert exceptionCaptor.getValue() instanceof IllegalArgumentException;
assertEquals("Remote inference host name has private ip address: localhost", exceptionCaptor.getValue().getMessage());
assert exceptionCaptor.getValue() instanceof NullPointerException;
assertEquals("host must not be null.", exceptionCaptor.getValue().getMessage());
}

@Test
Expand Down
Loading

0 comments on commit 05d14bc

Please sign in to comment.