diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index e3b2ea21e2..eb8b805175 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -10,6 +10,7 @@ import java.security.AccessController; import java.security.PrivilegedExceptionAction; +import java.time.Duration; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -36,7 +37,7 @@ @Log4j2 @ConnectorExecutor(AWS_SIGV4) -public class AwsConnectorExecutor implements RemoteConnectorExecutor { +public class AwsConnectorExecutor extends AbstractConnectorExecutor { @Getter private AwsConnector connector; @@ -56,8 +57,12 @@ public class AwsConnectorExecutor implements RemoteConnectorExecutor { private SdkAsyncHttpClient httpClient; public AwsConnectorExecutor(Connector connector) { + super.initialize(connector); this.connector = (AwsConnector) connector; - this.httpClient = MLHttpClientFactory.getAsyncHttpClient(); + Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); + Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); + Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); + this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection); } @Override diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index f65f614528..ad7a77308a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -11,6 +11,7 @@ import java.security.AccessController; import java.security.PrivilegedExceptionAction; +import java.time.Duration; import java.util.List; import java.util.Locale; import java.util.Map; @@ -61,7 +62,10 @@ public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor { public HttpJsonConnectorExecutor(Connector connector) { super.initialize(connector); this.connector = (HttpConnector) connector; - this.httpClient = MLHttpClientFactory.getAsyncHttpClient(); + Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); + Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); + Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); + this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection); } @Override @@ -103,8 +107,4 @@ public void invokeRemoteModel( actionListener.onFailure(new MLException("Fail to execute http connector", e)); } } - - public SdkAsyncHttpClient getHttpClient(int connectionTimeout, int readTimeout, int maxConnections) { - return MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnections); - } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java index eeb68ba86b..339523b313 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java @@ -11,6 +11,7 @@ import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.time.Duration; import java.util.Arrays; import java.util.Locale; @@ -21,11 +22,16 @@ @Log4j2 public class MLHttpClientFactory { - public static SdkAsyncHttpClient getAsyncHttpClient() { + public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, Duration readTimeout, int maxConnections) { try { return AccessController .doPrivileged( - (PrivilegedExceptionAction) () -> NettyNioAsyncHttpClient.builder().maxConcurrency(100).build() + (PrivilegedExceptionAction) () -> NettyNioAsyncHttpClient + .builder() + .connectionTimeout(connectionTimeout) + .readTimeout(readTimeout) + .maxConcurrency(maxConnections) + .build() ); } catch (PrivilegedActionException e) { return null;