diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 22eb7cc843..139ecf82be 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -156,6 +156,9 @@ public class CommonValue { + AbstractConnector.CREDENTIAL_FIELD + "\" : {\"type\": \"flat_object\"},\n" + " \"" + + AbstractConnector.CLIENT_CONFIG_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + AbstractConnector.ACTIONS_FIELD + "\" : {\"type\": \"flat_object\"}\n"; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java index 5fa213db99..fadab3ef9a 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java @@ -43,6 +43,8 @@ public abstract class AbstractConnector implements Connector { public static final String BACKEND_ROLES_FIELD = "backend_roles"; public static final String OWNER_FIELD = "owner"; public static final String ACCESS_FIELD = "access"; + public static final String CLIENT_CONFIG_FIELD = "client_config"; + protected String name; protected String description; @@ -65,6 +67,8 @@ public abstract class AbstractConnector implements Connector { protected AccessMode access; protected Instant createdTime; protected Instant lastUpdateTime; + @Setter + protected ConnectorClientConfig connectorClientConfig; protected Map createPredictDecryptedHeaders(Map headers) { if (headers == null) { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java index ed9c64ac94..4052b45874 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java @@ -31,8 +31,10 @@ public class AwsConnector extends HttpConnector { @Builder(builderMethodName = "awsConnectorBuilder") public AwsConnector(String name, String description, String version, String protocol, Map parameters, Map credential, List actions, - List backendRoles, AccessMode accessMode, User owner) { - super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode, owner); + List backendRoles, AccessMode accessMode, User owner, + ConnectorClientConfig connectorClientConfig) { + super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode, + owner, connectorClientConfig); validate(); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index 15b7456f3c..38a18b3882 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -53,6 +53,9 @@ public interface Connector extends ToXContentObject, Writeable { Map getParameters(); List getActions(); + + ConnectorClientConfig getConnectorClientConfig(); + String getPredictEndpoint(Map parameters); String getPredictHttpMethod(); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorClientConfig.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorClientConfig.java new file mode 100644 index 0000000000..bf29271e2e --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorClientConfig.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.connector; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +@Getter +@EqualsAndHashCode +public class ConnectorClientConfig implements ToXContentObject, Writeable { + + public static final String MAX_CONNECTION_FIELD = "max_connection"; + public static final String CONNECTION_TIMEOUT_FIELD = "connection_timeout"; + public static final String READ_TIMEOUT_FIELD = "read_timeout"; + + public static final Integer MAX_CONNECTION_DEFAULT_VALUE = Integer.valueOf(30); + public static final Integer CONNECTION_TIMEOUT_DEFAULT_VALUE = Integer.valueOf(30000); + public static final Integer READ_TIMEOUT_DEFAULT_VALUE = Integer.valueOf(30000); + + private Integer maxConnections; + private Integer connectionTimeout; + private Integer readTimeout; + + @Builder(toBuilder = true) + public ConnectorClientConfig( + Integer maxConnections, + Integer connectionTimeout, + Integer readTimeout + ) { + this.maxConnections = maxConnections; + this.connectionTimeout = connectionTimeout; + this.readTimeout = readTimeout; + + } + + public ConnectorClientConfig(StreamInput input) throws IOException { + this.maxConnections = input.readOptionalInt(); + this.connectionTimeout = input.readOptionalInt(); + this.readTimeout = input.readOptionalInt(); + } + + public ConnectorClientConfig() { + this.maxConnections = MAX_CONNECTION_DEFAULT_VALUE; + this.connectionTimeout = CONNECTION_TIMEOUT_DEFAULT_VALUE; + this.readTimeout = READ_TIMEOUT_DEFAULT_VALUE; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + + out.writeOptionalInt(maxConnections); + out.writeOptionalInt(connectionTimeout); + out.writeOptionalInt(readTimeout); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + XContentBuilder builder = xContentBuilder.startObject(); + if (maxConnections != null) { + builder.field(MAX_CONNECTION_FIELD, maxConnections); + } + if (connectionTimeout != null) { + builder.field(CONNECTION_TIMEOUT_FIELD, connectionTimeout); + } + if (readTimeout != null) { + builder.field(READ_TIMEOUT_FIELD, readTimeout); + } + return builder.endObject(); + } + + public static ConnectorClientConfig fromStream(StreamInput in) throws IOException { + ConnectorClientConfig connectorClientConfig = new ConnectorClientConfig(in); + return connectorClientConfig; + } + + public static ConnectorClientConfig parse(XContentParser parser) throws IOException { + Integer maxConnections = null; + Integer connectionTimeout = null; + Integer readTimeout = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case MAX_CONNECTION_FIELD: + maxConnections = parser.intValue(); + break; + case CONNECTION_TIMEOUT_FIELD: + connectionTimeout = parser.intValue(); + break; + case READ_TIMEOUT_FIELD: + readTimeout = parser.intValue(); + break; + default: + parser.skipChildren(); + break; + } + } + return ConnectorClientConfig.builder() + .maxConnections(maxConnections) + .connectionTimeout(connectionTimeout) + .readTimeout(readTimeout) + .build(); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index ca37bc40dc..734602478b 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -52,7 +52,8 @@ public class HttpConnector extends AbstractConnector { @Builder public HttpConnector(String name, String description, String version, String protocol, Map parameters, Map credential, List actions, - List backendRoles, AccessMode accessMode, User owner) { + List backendRoles, AccessMode accessMode, User owner, + ConnectorClientConfig connectorClientConfig) { validateProtocol(protocol); this.name = name; this.description = description; @@ -64,6 +65,8 @@ public HttpConnector(String name, String description, String version, String pro this.backendRoles = backendRoles; this.access = accessMode; this.owner = owner; + this.connectorClientConfig = connectorClientConfig; + } public HttpConnector(String protocol, XContentParser parser) throws IOException { @@ -121,6 +124,9 @@ public HttpConnector(String protocol, XContentParser parser) throws IOException case LAST_UPDATED_TIME_FIELD: lastUpdateTime = Instant.ofEpochMilli(parser.longValue()); break; + case CLIENT_CONFIG_FIELD: + connectorClientConfig = ConnectorClientConfig.parse(parser); + break; default: parser.skipChildren(); break; @@ -167,6 +173,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (lastUpdateTime != null) { builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli()); } + if (connectorClientConfig != null) { + builder.field(CLIENT_CONFIG_FIELD, connectorClientConfig); + } builder.endObject(); return builder; } @@ -205,6 +214,11 @@ private void parseFromStream(StreamInput input) throws IOException { if (input.readBoolean()) { this.owner = new User(input); } + this.createdTime = input.readOptionalInstant(); + this.lastUpdateTime = input.readOptionalInstant(); + if (input.readBoolean()) { + this.connectorClientConfig = new ConnectorClientConfig(input); + } } @Override @@ -247,6 +261,14 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + out.writeOptionalInstant(createdTime); + out.writeOptionalInstant(lastUpdateTime); + if (connectorClientConfig != null) { + out.writeBoolean(true); + connectorClientConfig.writeTo(out); + } else { + out.writeBoolean(false); + } } @Override @@ -279,6 +301,9 @@ public void update(MLCreateConnectorInput updateContent, Function backendRoles; private Boolean addAllBackendRoles; private AccessMode access; - private boolean dryRun = false; - private boolean updateConnector = false; + private boolean dryRun; + private boolean updateConnector; + private ConnectorClientConfig connectorClientConfig; + @Builder(toBuilder = true) public MLCreateConnectorInput(String name, @@ -70,7 +76,9 @@ public MLCreateConnectorInput(String name, Boolean addAllBackendRoles, AccessMode access, boolean dryRun, - boolean updateConnector + boolean updateConnector, + ConnectorClientConfig connectorClientConfig + ) { if (!dryRun && !updateConnector) { if (name == null) { @@ -95,6 +103,8 @@ public MLCreateConnectorInput(String name, this.access = access; this.dryRun = dryRun; this.updateConnector = updateConnector; + this.connectorClientConfig = connectorClientConfig; + } public static MLCreateConnectorInput parse(XContentParser parser) throws IOException { @@ -113,6 +123,7 @@ public static MLCreateConnectorInput parse(XContentParser parser, boolean update Boolean addAllBackendRoles = null; AccessMode access = null; boolean dryRun = false; + ConnectorClientConfig connectorClientConfig = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -161,12 +172,16 @@ public static MLCreateConnectorInput parse(XContentParser parser, boolean update case DRY_RUN_FIELD: dryRun = parser.booleanValue(); break; + case AbstractConnector.CLIENT_CONFIG_FIELD: + connectorClientConfig = ConnectorClientConfig.parse(parser); + break; default: parser.skipChildren(); break; } } - return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun, updateConnector); + return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, + backendRoles, addAllBackendRoles, access, dryRun, updateConnector, connectorClientConfig); } @Override @@ -202,6 +217,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (access != null) { builder.field(ACCESS_MODE_FIELD, access); } + if (connectorClientConfig != null) { + builder.field(AbstractConnector.CLIENT_CONFIG_FIELD, connectorClientConfig); + } builder.endObject(); return builder; } @@ -248,6 +266,13 @@ public void writeTo(StreamOutput output) throws IOException { } output.writeBoolean(dryRun); output.writeBoolean(updateConnector); + if (connectorClientConfig != null) { + output.writeBoolean(true); + connectorClientConfig.writeTo(output); + } else { + output.writeBoolean(false); + } + } public MLCreateConnectorInput(StreamInput input) throws IOException { @@ -277,5 +302,8 @@ public MLCreateConnectorInput(StreamInput input) throws IOException { } dryRun = input.readBoolean(); updateConnector = input.readBoolean(); + if (input.readBoolean()) { + this.connectorClientConfig = new ConnectorClientConfig(input); + } } } diff --git a/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java b/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java index 10c5d0515f..bd7a214dae 100644 --- a/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java +++ b/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java @@ -55,17 +55,19 @@ public void toXContent_InternalConnector() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlModel.toXContent(builder, EMPTY_PARAMS); String mlModelContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\",\"algorithm\":\"REMOTE\"," + - "\"model_version\":\"1.0.0\",\"description\":\"test model\",\"connector\":{\"name\":\"test_connector_name\"," + - "\"version\":\"1\",\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + assertEquals("{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\"," + + "\"algorithm\":\"REMOTE\",\"model_version\":\"1.0.0\",\"description\":\"test model\"," + + "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + "\"headers\":{\"api_key\":\"${credential.key}\"}," + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"]," + - "\"access\":\"public\"}}", mlModelContent); + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + + "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000}}}", + mlModelContent); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorClientConfigTest.java b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorClientConfigTest.java new file mode 100644 index 0000000000..f28c8fb169 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorClientConfigTest.java @@ -0,0 +1,74 @@ +package org.opensearch.ml.common.connector; + +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.settings.Settings; +import org.opensearch.search.SearchModule; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.TestHelper; + +import java.io.IOException; +import java.util.Collections; + +public class ConnectorClientConfigTest { + + @Test + public void writeTo_ReadFromStream() throws IOException { + ConnectorClientConfig config = ConnectorClientConfig.builder() + .maxConnections(10) + .connectionTimeout(5000) + .readTimeout(3000) + .build(); + + BytesStreamOutput output = new BytesStreamOutput(); + config.writeTo(output); + ConnectorClientConfig readConfig = new ConnectorClientConfig(output.bytes().streamInput()); + + Assert.assertEquals(config, readConfig); + } + + @Test + public void toXContent() throws IOException { + ConnectorClientConfig config = ConnectorClientConfig.builder() + .maxConnections(10) + .connectionTimeout(5000) + .readTimeout(3000) + .build(); + + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + config.toXContent(builder, ToXContent.EMPTY_PARAMS); + String content = TestHelper.xContentBuilderToString(builder); + + String expectedJson = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000}"; + Assert.assertEquals(expectedJson, content); + } + + @Test + public void parse() throws IOException { + String jsonStr = "{\"max_connection\":10,\"connection_timeout\":5000,\"read_timeout\":3000}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); + + ConnectorClientConfig config = ConnectorClientConfig.parse(parser); + + Assert.assertEquals(Integer.valueOf(10), config.getMaxConnections()); + Assert.assertEquals(Integer.valueOf(5000), config.getConnectionTimeout()); + Assert.assertEquals(Integer.valueOf(3000), config.getReadTimeout()); + } + + @Test + public void testDefaultValues() { + ConnectorClientConfig config = ConnectorClientConfig.builder().build(); + + Assert.assertNull(config.getMaxConnections()); + Assert.assertNull(config.getConnectionTimeout()); + Assert.assertNull(config.getReadTimeout()); + } +} + diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index f6bbebf8a5..48e19aea76 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -30,6 +30,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Optional; import java.util.function.Function; public class HttpConnectorTest { @@ -39,6 +40,17 @@ public class HttpConnectorTest { Function encryptFunction; Function decryptFunction; + String TEST_CONNECTOR_JSON_STRING = "{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + + "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000}}"; + @Before public void setUp() { encryptFunction = s -> "encrypted: "+s.toLowerCase(Locale.ROOT); @@ -71,33 +83,15 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); connector.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"name\":\"test_connector_name\",\"version\":\"1\",\"description\":\"this is a test connector\"," + - "\"protocol\":\"http\"," + - "\"parameters\":{\"input\":\"test input value\"}," + - "\"credential\":{\"key\":\"test_key_value\"}," + - "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + - "\"headers\":{\"api_key\":\"${credential.key}\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + - "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + - "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"]," + - "\"access\":\"public\"}", content); + Assert.assertEquals(TEST_CONNECTOR_JSON_STRING, content); } @Test public void constructor_Parser() throws IOException { - String jsonStr = "{\"name\":\"test_connector_name\",\"version\":\"1\",\"description\":\"this is a test connector\"," + - "\"protocol\":\"http\"," + - "\"parameters\":{\"input\":\"test input value\"}," + - "\"credential\":{\"key\":\"test_key_value\"}," + - "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + - "\"headers\":{\"api_key\":\"${credential.key}\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + - "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + - "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"]," + - "\"access\":\"public\"}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + Collections.emptyList()).getNamedXContents()), null, TEST_CONNECTOR_JSON_STRING); parser.nextToken(); HttpConnector connector = new HttpConnector("http", parser); @@ -299,6 +293,8 @@ public static HttpConnector createHttpConnectorWithRequestBody(String requestBod Map credential = new HashMap<>(); credential.put("key", "test_key_value"); + ConnectorClientConfig httpClientConfig = new ConnectorClientConfig(30, 30000, 30000); + HttpConnector connector = HttpConnector.builder() .name("test_connector_name") .description("this is a test connector") @@ -309,6 +305,7 @@ public static HttpConnector createHttpConnectorWithRequestBody(String requestBod .actions(Arrays.asList(action)) .backendRoles(Arrays.asList("role1", "role2")) .accessMode(AccessMode.PUBLIC) + .connectorClientConfig(httpClientConfig) .build(); return connector; } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java index 1b665d0c6b..f73b818c6c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java @@ -57,16 +57,17 @@ public void toXContentTest() throws IOException { mlConnectorGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("{\"name\":\"test_connector_name\"," + - "\"version\":\"1\",\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + assertEquals("{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + "\"headers\":{\"api_key\":\"${credential.key}\"}," + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"]," + - "\"access\":\"public\"}", jsonStr); + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + + "\"client_config\":{\"max_connection\":30," + + "\"connection_timeout\":30000,\"read_timeout\":30000}}", jsonStr); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java index 0f0b1ad82b..34844b316b 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -5,11 +5,6 @@ package org.opensearch.ml.common.transport.connector; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; - import java.io.IOException; import java.util.Arrays; import java.util.Collections; @@ -24,19 +19,33 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.ConnectorClientConfig; import org.opensearch.ml.common.connector.MLPostProcessFunction; import org.opensearch.ml.common.connector.MLPreProcessFunction; import org.opensearch.search.SearchModule; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + public class MLCreateConnectorInputTests { private MLCreateConnectorInput mlCreateConnectorInput; private MLCreateConnectorInput mlCreateDryRunConnectorInput; @@ -52,7 +61,8 @@ public class MLCreateConnectorInputTests { "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," + - "\"access_mode\":\"PUBLIC\"}"; + "\"access_mode\":\"PUBLIC\",\"client_config\":{\"max_connection\":20," + + "\"connection_timeout\":10000,\"read_timeout\":10000}}"; @Before public void setUp(){ @@ -65,6 +75,7 @@ public void setUp(){ String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; ConnectorAction action = new ConnectorAction(actionType, method, url, headers, mlCreateConnectorRequestBody, preProcessFunction, postProcessFunction); + ConnectorClientConfig connectorClientConfig = new ConnectorClientConfig(20, 10000, 10000); mlCreateConnectorInput = MLCreateConnectorInput.builder() .name("test_connector_name") @@ -77,6 +88,7 @@ public void setUp(){ .access(AccessMode.PUBLIC) .backendRoles(Arrays.asList("role1", "role2")) .addAllBackendRoles(false) + .connectorClientConfig(connectorClientConfig) .build(); mlCreateDryRunConnectorInput = MLCreateConnectorInput.builder() @@ -160,6 +172,9 @@ public void testToXContent_NullFields() throws Exception { public void testParse() throws Exception { testParseFromJsonString(expectedInputStr, parsedInput -> { assertEquals("test_connector_name", parsedInput.getName()); + assertEquals(20, parsedInput.getConnectorClientConfig().getMaxConnections().intValue()); + assertEquals(10000, parsedInput.getConnectorClientConfig().getReadTimeout().intValue()); + assertEquals(10000, parsedInput.getConnectorClientConfig().getConnectionTimeout().intValue()); }); } @@ -206,6 +221,7 @@ public void readInputStream_SuccessWithNullFields() throws IOException { readInputStream(mlCreateMinimalConnectorInput, parsedInput -> { assertEquals(mlCreateMinimalConnectorInput.getName(), parsedInput.getName()); assertNull(parsedInput.getActions()); + assertNull(parsedInput.getConnectorClientConfig()); }); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java index 5f7626dd83..56a7e81656 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java @@ -58,7 +58,9 @@ public class MLRegisterModelInputTest { "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"},\"is_hidden\":false}"; + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + + "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000," + + "\"read_timeout\":30000}},\"is_hidden\":false}"; private final FunctionName functionName = FunctionName.LINEAR_REGRESSION; private final String modelName = "modelName"; private final String version = "version"; @@ -169,18 +171,20 @@ public void testToXContent() throws Exception { @Test public void testToXContent_Incomplete() throws Exception { - String expectedIncompleteInputStr = - "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\"," + - "\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"description\":\"test description\"," + - "\"model_content_hash_value\":\"hash_value_test\",\"deploy_model\":true,\"connector\":" + - "{\"name\":\"test_connector_name\",\"version\":\"1\",\"description\":\"this is a test connector\"," + - "\"protocol\":\"http\",\"parameters\":{\"input\":\"test input value\"}," + - "\"credential\":{\"key\":\"test_key_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":" + - "\"POST\",\"url\":\"https://test.com\",\"headers\":{\"api_key\":\"${credential.key}\"}," + - "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\",\"pre_process_function\":" + - "\"connector.pre_process.openai.embedding\",\"post_process_function\":" + - "\"connector.post_process.openai.embedding\"}],\"backend_roles\":[\"role1\",\"role2\"]," + - "\"access\":\"public\"},\"is_hidden\":false}"; + String expectedIncompleteInputStr = "{\"function_name\":\"LINEAR_REGRESSION\"," + + "\"name\":\"modelName\",\"version\":\"version\",\"model_group_id\":\"modelGroupId\"," + + "\"description\":\"test description\",\"model_content_hash_value\":\"hash_value_test\"," + + "\"deploy_model\":true,\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"}," + + "\"credential\":{\"key\":\"test_key_value\"},\"actions\":[{\"action_type\":\"PREDICT\"," + + "\"method\":\"POST\",\"url\":\"https://test.com\",\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + + "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000," + + "\"read_timeout\":30000}},\"is_hidden\":false}"; input.setUrl(null); input.setModelConfig(null); input.setModelFormat(null); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutor.java new file mode 100644 index 0000000000..46c653776d --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutor.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.remote; + +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorClientConfig; + +import lombok.Getter; +import lombok.Setter; + +@Setter +@Getter +public abstract class AbstractConnectorExecutor implements RemoteConnectorExecutor { + private ConnectorClientConfig connectorClientConfig; + + public void initialize(Connector connector) { + if (connector.getConnectorClientConfig() != null) { + connectorClientConfig = connector.getConnectorClientConfig(); + } else { + connectorClientConfig = new ConnectorClientConfig(); + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 0e8169ac64..f06450278d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -16,6 +16,7 @@ import java.nio.charset.StandardCharsets; import java.security.AccessController; import java.security.PrivilegedExceptionAction; +import java.time.Duration; import java.util.List; import java.util.Map; @@ -40,15 +41,17 @@ import software.amazon.awssdk.http.HttpExecuteRequest; import software.amazon.awssdk.http.HttpExecuteResponse; import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.SdkHttpConfigurationOption; import software.amazon.awssdk.http.SdkHttpFullRequest; +import software.amazon.awssdk.utils.AttributeMap; @Log4j2 @ConnectorExecutor(AWS_SIGV4) -public class AwsConnectorExecutor implements RemoteConnectorExecutor { +public class AwsConnectorExecutor extends AbstractConnectorExecutor { @Getter private AwsConnector connector; - private final SdkHttpClient httpClient; + private SdkHttpClient httpClient; @Setter @Getter private ScriptService scriptService; @@ -68,7 +71,33 @@ public AwsConnectorExecutor(Connector connector, SdkHttpClient httpClient) { } public AwsConnectorExecutor(Connector connector) { - this(connector, new DefaultSdkHttpClientBuilder().build()); + super.initialize(connector); + this.connector = (AwsConnector) connector; + Duration connectionTimeout = Duration.ofMillis(super.getConnectorClientConfig().getConnectionTimeout()); + Duration readTimeout = Duration.ofMillis(super.getConnectorClientConfig().getReadTimeout()); + try ( + AttributeMap attributeMap = AttributeMap + .builder() + .put(SdkHttpConfigurationOption.CONNECTION_TIMEOUT, connectionTimeout) + .put(SdkHttpConfigurationOption.READ_TIMEOUT, readTimeout) + .put(SdkHttpConfigurationOption.MAX_CONNECTIONS, super.getConnectorClientConfig().getMaxConnections()) + .build() + ) { + log + .info( + "Initializing aws connector http client with attributes: connectionTimeout={}, readTimeout={}, maxConnections={}", + connectionTimeout, + readTimeout, + super.getConnectorClientConfig().getMaxConnections() + ); + this.httpClient = new DefaultSdkHttpClientBuilder().buildWithDefaults(attributeMap); + } catch (RuntimeException e) { + log.error("Error initializing AWS connector HTTP client.", e); + throw e; + } catch (Throwable e) { + log.error("Error initializing AWS connector HTTP client.", e); + throw new MLException(e); + } } @Override @@ -95,9 +124,8 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S .contentStreamProvider(request.contentStreamProvider().orElse(null)) .build(); - HttpExecuteResponse response = AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - return httpClient.prepareRequest(executeRequest).call(); - }); + HttpExecuteResponse response = AccessController + .doPrivileged((PrivilegedExceptionAction) () -> httpClient.prepareRequest(executeRequest).call()); int statusCode = response.httpResponse().statusCode(); AbortableInputStream body = null; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index cc3670e5d1..b1ef5b500c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -43,7 +43,7 @@ @Log4j2 @ConnectorExecutor(HTTP) -public class HttpJsonConnectorExecutor implements RemoteConnectorExecutor { +public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor { @Getter private HttpConnector connector; @@ -61,8 +61,22 @@ public class HttpJsonConnectorExecutor implements RemoteConnectorExecutor { @Getter private Client client; + private CloseableHttpClient httpClient; + public HttpJsonConnectorExecutor(Connector connector) { + super.initialize(connector); this.connector = (HttpConnector) connector; + this.httpClient = MLHttpClientFactory + .getCloseableHttpClient( + super.getConnectorClientConfig().getConnectionTimeout(), + super.getConnectorClientConfig().getReadTimeout(), + super.getConnectorClientConfig().getMaxConnections() + ); + } + + public HttpJsonConnectorExecutor(Connector connector, CloseableHttpClient httpClient) { + this(connector); + this.httpClient = httpClient; } @Override @@ -110,7 +124,7 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S } AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - try (CloseableHttpClient httpClient = getHttpClient(); CloseableHttpResponse response = httpClient.execute(request)) { + try (CloseableHttpResponse response = httpClient.execute(request)) { HttpEntity responseEntity = response.getEntity(); String responseBody = EntityUtils.toString(responseEntity); EntityUtils.consume(responseEntity); @@ -137,7 +151,4 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S } } - public CloseableHttpClient getHttpClient() { - return MLHttpClientFactory.getCloseableHttpClient(); - } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java index e5177b3697..c981ebc184 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java @@ -14,6 +14,7 @@ import org.apache.http.HttpHost; import org.apache.http.HttpRequest; import org.apache.http.HttpResponse; +import org.apache.http.client.config.RequestConfig; import org.apache.http.conn.UnsupportedSchemeException; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClientBuilder; @@ -28,11 +29,11 @@ @Log4j2 public class MLHttpClientFactory { - public static CloseableHttpClient getCloseableHttpClient() { - return createHttpClient(); + public static CloseableHttpClient getCloseableHttpClient(Integer connectionTimeout, Integer readTimeout, Integer maxConnections) { + return createHttpClient(connectionTimeout, readTimeout, maxConnections); } - private static CloseableHttpClient createHttpClient() { + private static CloseableHttpClient createHttpClient(Integer connectionTimeout, Integer readTimeout, Integer maxConnections) { HttpClientBuilder builder = HttpClientBuilder.create(); // Only allow HTTP and HTTPS schemes @@ -53,6 +54,10 @@ public boolean isRedirected(HttpRequest request, HttpResponse response, HttpCont return false; } }); + builder.setMaxConnTotal(maxConnections); + builder.setMaxConnPerRoute(maxConnections); + RequestConfig requestConfig = RequestConfig.custom().setConnectTimeout(connectionTimeout).setSocketTimeout(readTimeout).build(); + builder.setDefaultRequestConfig(requestConfig); return builder.build(); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java new file mode 100644 index 0000000000..56ad15cbd4 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java @@ -0,0 +1,45 @@ +package org.opensearch.ml.engine.algorithms.remote; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ml.common.connector.AwsConnector; +import org.opensearch.ml.common.connector.ConnectorClientConfig; + +public class AbstractConnectorExecutorTest { + @Mock + private AwsConnector mockConnector; + + private ConnectorClientConfig connectorClientConfig; + + private AbstractConnectorExecutor executor; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + executor = new AwsConnectorExecutor(mockConnector); + connectorClientConfig = new ConnectorClientConfig(); + } + + @Test + public void testValidateWithNullConfig() { + when(mockConnector.getConnectorClientConfig()).thenReturn(null); + executor.initialize(mockConnector); + assertEquals(ConnectorClientConfig.MAX_CONNECTION_DEFAULT_VALUE, executor.getConnectorClientConfig().getMaxConnections()); + assertEquals(ConnectorClientConfig.CONNECTION_TIMEOUT_DEFAULT_VALUE, executor.getConnectorClientConfig().getConnectionTimeout()); + assertEquals(ConnectorClientConfig.READ_TIMEOUT_DEFAULT_VALUE, executor.getConnectorClientConfig().getReadTimeout()); + } + + @Test + public void testValidateWithNonNullConfigButNullValues() { + when(mockConnector.getConnectorClientConfig()).thenReturn(connectorClientConfig); + executor.initialize(mockConnector); + assertEquals(ConnectorClientConfig.MAX_CONNECTION_DEFAULT_VALUE, executor.getConnectorClientConfig().getMaxConnections()); + assertEquals(ConnectorClientConfig.CONNECTION_TIMEOUT_DEFAULT_VALUE, executor.getConnectorClientConfig().getConnectionTimeout()); + assertEquals(ConnectorClientConfig.READ_TIMEOUT_DEFAULT_VALUE, executor.getConnectorClientConfig().getReadTimeout()); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 0119169e2a..69df5b03ae 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -36,6 +36,7 @@ import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.ConnectorClientConfig; import org.opensearch.ml.common.connector.MLPreProcessFunction; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; @@ -296,4 +297,35 @@ public void executePredict_TextDocsInferenceInput() throws IOException { Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key")); } + + @Test + public void test_initialize() { + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Map credential = ImmutableMap + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); + ConnectorClientConfig httpClientConfig = new ConnectorClientConfig(20, 30000, 30000); + Connector connector = AwsConnector + .awsConnectorBuilder() + .name("test connector") + .version("1") + .protocol("http") + .parameters(parameters) + .credential(credential) + .actions(Arrays.asList(predictAction)) + .connectorClientConfig(httpClientConfig) + .build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + Assert.assertEquals(20, executor.getConnector().getConnectorClientConfig().getMaxConnections().intValue()); + Assert.assertEquals(30000, executor.getConnector().getConnectorClientConfig().getConnectionTimeout().intValue()); + Assert.assertEquals(30000, executor.getConnector().getConnectorClientConfig().getReadTimeout().intValue()); + + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index ad42c1a3ca..122402fea3 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -94,7 +94,7 @@ public void invokeRemoteModel_WrongHttpMethod() { .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector); + HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector, httpClient); executor.invokeRemoteModel(null, null, null, null); } @@ -114,7 +114,7 @@ public void executePredict_RemoteInferenceInput() throws IOException { .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(executor.getClient()).thenReturn(client); @@ -125,7 +125,6 @@ public void executePredict_RemoteInferenceInput() throws IOException { when(response.getEntity()).thenReturn(entity); StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); when(response.getStatusLine()).thenReturn(statusLine); - when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); ModelTensorOutput modelTensorOutput = executor .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); @@ -160,13 +159,12 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(executor.getClient()).thenReturn(client); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); @@ -203,13 +201,12 @@ public void executePredict_TextDocsInput_LimitExceed() throws IOException { .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(executor.getClient()).thenReturn(client); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); } @@ -238,7 +235,7 @@ public void executePredict_TextDocsInput() throws IOException { .protocol("http") .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(executor.getClient()).thenReturn(client); @@ -278,7 +275,6 @@ public void executePredict_TextDocsInput() throws IOException { when(response.getStatusLine()).thenReturn(statusLine); HttpEntity entity = new StringEntity(modelResponse); when(response.getEntity()).thenReturn(entity); - when(executor.getHttpClient()).thenReturn(httpClient); when(executor.getConnector()).thenReturn(connector); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor @@ -324,7 +320,7 @@ public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOE .parameters(parameters) .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(executor.getClient()).thenReturn(client); @@ -355,7 +351,6 @@ public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs() throws IOE when(response.getStatusLine()).thenReturn(statusLine); HttpEntity entity = new StringEntity(modelResponse); when(response.getEntity()).thenReturn(entity); - when(executor.getHttpClient()).thenReturn(httpClient); when(executor.getConnector()).thenReturn(connector); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor @@ -399,7 +394,7 @@ public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs_InvalidStepS .parameters(parameters) .actions(Arrays.asList(predictAction)) .build(); - HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector, httpClient)); Settings settings = Settings.builder().build(); threadContext = new ThreadContext(settings); when(executor.getClient()).thenReturn(client); @@ -430,7 +425,6 @@ public void executePredict_TextDocsInput_LessEmbeddingThanInputDocs_InvalidStepS when(response.getStatusLine()).thenReturn(statusLine); HttpEntity entity = new StringEntity(modelResponse); when(response.getEntity()).thenReturn(entity); - when(executor.getHttpClient()).thenReturn(httpClient); when(executor.getConnector()).thenReturn(connector); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java index 7e11c6fc22..9c98be7ee0 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java @@ -22,7 +22,7 @@ public class MLHttpClientFactoryTests { @Test public void test_getCloseableHttpClient_success() { - CloseableHttpClient client = MLHttpClientFactory.getCloseableHttpClient(); + CloseableHttpClient client = MLHttpClientFactory.getCloseableHttpClient(1000, 1000, 30); assertNotNull(client); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 0c59f3d2bc..1c14284697 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -31,6 +31,11 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase { + "\"name\": \"OpenAI Connector\",\n" + "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" + "\"version\": 1,\n" + + "\"client_config\": {\n" + + " \"max_connection\": 20,\n" + + " \"connection_timeout\": 50000,\n" + + " \"read_timeout\": 50000\n" + + " },\n" + "\"protocol\": \"http\",\n" + "\"parameters\": {\n" + " \"endpoint\": \"api.openai.com\",\n" @@ -259,6 +264,11 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep + " \"name\": \"OpenAI chat model Connector\",\n" + " \"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" + " \"version\": 1,\n" + + "\"client_config\": {\n" + + " \"max_connection\": 20,\n" + + " \"connection_timeout\": 50000,\n" + + " \"read_timeout\": 50000\n" + + " },\n" + " \"protocol\": \"http\",\n" + " \"parameters\": {\n" + " \"endpoint\": \"api.openai.com\",\n" @@ -320,6 +330,11 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException { + " \"description\": \"The connector to public OpenAI edit model service\",\n" + " \"version\": 1,\n" + " \"protocol\": \"http\",\n" + + "\"client_config\": {\n" + + " \"max_connection\": 20,\n" + + " \"connection_timeout\": 50000,\n" + + " \"read_timeout\": 50000\n" + + " },\n" + " \"parameters\": {\n" + " \"endpoint\": \"api.openai.com\",\n" + " \"auth\": \"API_Key\",\n" @@ -385,6 +400,11 @@ public void testOpenAIModerationsModel() throws IOException, InterruptedExceptio + " \"name\": \"OpenAI moderations model Connector\",\n" + " \"description\": \"The connector to public OpenAI moderations model service\",\n" + " \"version\": 1,\n" + + "\"client_config\": {\n" + + " \"max_connection\": 20,\n" + + " \"connection_timeout\": 50000,\n" + + " \"read_timeout\": 50000\n" + + " },\n" + " \"protocol\": \"http\",\n" + " \"parameters\": {\n" + " \"endpoint\": \"api.openai.com\",\n" @@ -472,6 +492,11 @@ private void testOpenAITextEmbeddingModel(String charset, Consumer verifyRe + " \"name\": \"OpenAI text embedding model Connector\",\n" + " \"description\": \"The connector to public OpenAI text embedding model service\",\n" + " \"version\": 1,\n" + + "\"client_config\": {\n" + + " \"max_connection\": 20,\n" + + " \"connection_timeout\": 50000,\n" + + " \"read_timeout\": 50000\n" + + " },\n" + " \"protocol\": \"http\",\n" + " \"parameters\": {\n" + " \"model\": \"text-embedding-ada-002\"\n" @@ -535,6 +560,11 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti + " \"name\": \"Cohere generate text model Connector\",\n" + " \"description\": \"The connector to public Cohere generate text model service\",\n" + " \"version\": 1,\n" + + "\"client_config\": {\n" + + " \"max_connection\": 20,\n" + + " \"connection_timeout\": 50000,\n" + + " \"read_timeout\": 50000\n" + + " },\n" + " \"protocol\": \"http\",\n" + " \"parameters\": {\n" + " \"endpoint\": \"api.cohere.ai\",\n" @@ -600,6 +630,11 @@ public void testCohereClassifyModel() throws IOException, InterruptedException { + " \"name\": \"Cohere classify model Connector\",\n" + " \"description\": \"The connector to public Cohere classify model service\",\n" + " \"version\": 1,\n" + + " \"client_config\": {\n" + + " \"max_connection\": 20,\n" + + " \"connection_timeout\": 50000,\n" + + " \"read_timeout\": 50000\n" + + " },\n" + " \"protocol\": \"http\",\n" + " \"parameters\": {\n" + " \"endpoint\": \"api.cohere.ai\",\n" diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java index 81b52818b4..bbf733ca62 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java @@ -58,6 +58,8 @@ public class RestMLUpdateConnectorActionTests extends OpenSearchTestCase { @Mock RestChannel channel; + private String REST_PATH = "/_plugins/_ml/connectors/{connector_id}"; + @Before public void setup() { MockitoAnnotations.openMocks(this); @@ -92,7 +94,7 @@ public void testRoutes() { assertFalse(routes.isEmpty()); RestHandler.Route route = routes.get(0); assertEquals(RestRequest.Method.PUT, route.getMethod()); - assertEquals("/_plugins/_ml/connectors/{connector_id}", route.getPath()); + assertEquals(REST_PATH, route.getPath()); } public void testUpdateConnectorRequest() throws Exception { @@ -109,6 +111,7 @@ public void testUpdateConnectorRequest() throws Exception { assertEquals("test_connectorId", updateConnectorRequest.getConnectorId()); assertEquals("This is test description", updateConnectorRequest.getUpdateContent().getDescription()); assertEquals("2", updateConnectorRequest.getUpdateContent().getVersion()); + } public void testUpdateConnectorRequestWithParsingException() throws Exception { @@ -149,7 +152,7 @@ private RestRequest getRestRequest() { params.put("connector_id", "test_connectorId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/connectors/{connector_id}") + .withPath(REST_PATH) .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); @@ -163,7 +166,7 @@ private RestRequest getRestRequestWithNullValue() { params.put("connector_id", "test_connectorId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/connectors/{connector_id}") + .withPath(REST_PATH) .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); @@ -176,7 +179,7 @@ private RestRequest getRestRequestWithEmptyContent() { params.put("connector_id", "test_connectorId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/connectors/{connector_id}") + .withPath(REST_PATH) .withParams(params) .withContent(new BytesArray(""), XContentType.JSON) .build(); @@ -190,7 +193,7 @@ private RestRequest getRestRequestWithNullConnectorId() { Map params = new HashMap<>(); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/connectors/{connector_id}") + .withPath(REST_PATH) .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build();