diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 4cfd6c6607..a6181e1b2f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -16,7 +16,6 @@ import java.io.IOException; import java.net.URI; -import java.net.URL; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.HashMap; @@ -40,7 +39,6 @@ import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; import org.opensearch.script.ScriptService; import com.jayway.jsonpath.JsonPath; @@ -269,13 +267,7 @@ public static SdkHttpFullRequest buildSdkRequest( Map parameters, String payload, SdkHttpMethod method - ) throws Exception { - String endpoint = connector.getPredictEndpoint(parameters); - URL url = new URL(endpoint); - String protocol = url.getProtocol(); - String host = url.getHost(); - int port = url.getPort(); - MLHttpClientFactory.validate(protocol, host, port); + ) { String charset = parameters.getOrDefault("charset", "UTF-8"); RequestBody requestBody; if (payload != null) { @@ -287,6 +279,7 @@ public static SdkHttpFullRequest buildSdkRequest( log.error("Content length is 0. Aborting request to remote model"); throw new IllegalArgumentException("Content length is 0. Aborting request to remote model"); } + String endpoint = connector.getPredictEndpoint(parameters); SdkHttpFullRequest.Builder builder = SdkHttpFullRequest .builder() .method(method) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 2a5ead4b1b..816e0528dc 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -9,6 +9,7 @@ import static software.amazon.awssdk.http.SdkHttpMethod.GET; import static software.amazon.awssdk.http.SdkHttpMethod.POST; +import java.net.URL; import java.security.AccessController; import java.security.PrivilegedExceptionAction; import java.time.Duration; @@ -30,6 +31,8 @@ import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; import org.opensearch.script.ScriptService; +import com.google.common.annotations.VisibleForTesting; + import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; @@ -87,9 +90,11 @@ public void invokeRemoteModel( switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) { case "POST": log.debug("original payload to remote model: " + payload); + validateHttpClientParameters(parameters); request = ConnectorUtils.buildSdkRequest(connector, parameters, payload, POST); break; case "GET": + validateHttpClientParameters(parameters); request = ConnectorUtils.buildSdkRequest(connector, parameters, null, GET); break; default: @@ -120,4 +125,14 @@ public void invokeRemoteModel( actionListener.onFailure(new MLException("Fail to execute http connector", e)); } } + + @VisibleForTesting + protected void validateHttpClientParameters(Map parameters) throws Exception { + String endpoint = connector.getPredictEndpoint(parameters); + URL url = new URL(endpoint); + String protocol = url.getProtocol(); + String host = url.getHost(); + int port = url.getPort(); + MLHttpClientFactory.validate(protocol, host, port); + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 94b4a7304a..5e1a9dfacb 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -38,7 +38,6 @@ import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.engine.encryptor.Encryptor; @@ -94,43 +93,6 @@ public void executePredict_RemoteInferenceInput_MissingCredential() { .build(); } - @Test - public void executePredict_RemoteInferenceInput_invalidIp() { - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test1.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Map credential = ImmutableMap - .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); - Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); - Connector connector = AwsConnector - .awsConnectorBuilder() - .name("test connector") - .version("1") - .protocol("http") - .parameters(parameters) - .credential(credential) - .actions(Arrays.asList(predictAction)) - .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); - AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); - Settings settings = Settings.builder().build(); - threadContext = new ThreadContext(settings); - when(executor.getClient()).thenReturn(client); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(threadContext); - - MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); - executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener); - ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); - Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); - assert exceptionCaptor.getValue() instanceof MLException; - assertEquals("Fail to execute predict in aws connector", exceptionCaptor.getValue().getMessage()); - } - @Test public void executePredict_RemoteInferenceInput_EmptyIpAddress() { ConnectorAction predictAction = ConnectorAction @@ -164,45 +126,8 @@ public void executePredict_RemoteInferenceInput_EmptyIpAddress() { executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); - assert exceptionCaptor.getValue() instanceof IllegalArgumentException; - assertEquals("Remote inference host name has private ip address: ", exceptionCaptor.getValue().getMessage()); - } - - @Test - public void executePredict_RemoteInferenceInput_illegalIpAddress() { - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://localhost/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Map credential = ImmutableMap - .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); - Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); - Connector connector = AwsConnector - .awsConnectorBuilder() - .name("test connector") - .version("1") - .protocol("http") - .parameters(parameters) - .credential(credential) - .actions(Arrays.asList(predictAction)) - .build(); - connector.decrypt((c) -> encryptor.decrypt(c)); - AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector)); - Settings settings = Settings.builder().build(); - threadContext = new ThreadContext(settings); - when(executor.getClient()).thenReturn(client); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(threadContext); - - MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); - executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(), actionListener); - ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); - Mockito.verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); - assert exceptionCaptor.getValue() instanceof IllegalArgumentException; - assertEquals("Remote inference host name has private ip address: localhost", exceptionCaptor.getValue().getMessage()); + assert exceptionCaptor.getValue() instanceof NullPointerException; + assertEquals("host must not be null.", exceptionCaptor.getValue().getMessage()); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java index f0f0292c59..d1d9b42dcc 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java @@ -6,23 +6,13 @@ package org.opensearch.ml.engine.httpclient; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import java.net.MalformedURLException; import java.time.Duration; -import java.util.Arrays; -import java.util.Map; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import org.opensearch.ml.common.connector.Connector; -import org.opensearch.ml.common.connector.ConnectorAction; -import org.opensearch.ml.common.connector.HttpConnector; -import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils; -import software.amazon.awssdk.http.SdkHttpFullRequest; -import software.amazon.awssdk.http.SdkHttpMethod; import software.amazon.awssdk.http.async.SdkAsyncHttpClient; public class MLHttpClientFactoryTests { @@ -38,162 +28,49 @@ public void test_getSdkAsyncHttpClient_success() { @Test public void test_validateIp_validIp_noException() throws Exception { - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://api.openai.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); - SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); - assertNotNull(request); + MLHttpClientFactory.validate("http", "api.openai.com", 80); } @Test public void test_validateIp_rarePrivateIp_throwException() throws Exception { try { - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://0254.020.00.01/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); - ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); + MLHttpClientFactory.validate("http", "0254.020.00.01", 80); } catch (IllegalArgumentException e) { assertNotNull(e); } try { - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://172.1048577/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); - ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); - } catch (IllegalArgumentException e) { + MLHttpClientFactory.validate("http", "172.1048577", 80); + } catch (Exception e) { assertNotNull(e); } try { - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://2886729729/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); - ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); + MLHttpClientFactory.validate("http", "2886729729", 80); } catch (IllegalArgumentException e) { assertNotNull(e); } try { - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://192.11010049/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); - ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); + MLHttpClientFactory.validate("http", "192.11010049", 80); } catch (IllegalArgumentException e) { assertNotNull(e); } try { - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://3232300545/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); - ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); + MLHttpClientFactory.validate("http", "3232300545", 80); } catch (IllegalArgumentException e) { assertNotNull(e); } try { - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://0:0:0:0:0:ffff:127.0.0.1/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); - ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); - } catch (MalformedURLException e) { + MLHttpClientFactory.validate("http", "0:0:0:0:0:ffff:127.0.0.1", 80); + } catch (IllegalArgumentException e) { assertNotNull(e); } try { - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://153.24.76.232/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); - ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); + MLHttpClientFactory.validate("http", "153.24.76.232", 80); } catch (IllegalArgumentException e) { assertNotNull(e); } @@ -201,84 +78,20 @@ public void test_validateIp_rarePrivateIp_throwException() throws Exception { @Test public void test_validateSchemaAndPort_success() throws Exception { - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://api.openai.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); - SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); - assertNotNull(request); + MLHttpClientFactory.validate("http", "api.openai.com", 80); } @Test public void test_validateSchemaAndPort_notAllowedSchema_throwException() throws Exception { expectedException.expect(IllegalArgumentException.class); - expectedException.expectMessage("Protocol is not http or https: ftp"); - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("ftp://api.openai.com:8080/v1/completions") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); - SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); - assertNull(request); + MLHttpClientFactory.validate("ftp", "api.openai.com", 80); } @Test public void test_validateSchemaAndPort_portNotInRange_throwException() throws Exception { expectedException.expect(IllegalArgumentException.class); expectedException.expectMessage("Port out of range: 65537"); - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("https://api.openai.com:65537/v1/completions") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); - ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); + MLHttpClientFactory.validate("https", "api.openai.com", 65537); } - @Test - public void test_validateSchemaAndPort_portNotANumber_throwException() throws Exception { - expectedException.expect(MalformedURLException.class); - ConnectorAction predictAction = ConnectorAction - .builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("https://api.openai.com:abc/v1/completions") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Connector connector = HttpConnector - .builder() - .name("test connector") - .version("1") - .protocol("http") - .actions(Arrays.asList(predictAction)) - .build(); - ConnectorUtils.buildSdkRequest(connector, Map.of(), "hello world", SdkHttpMethod.POST); - } }