Skip to content

Commit

Permalink
Merge branch 'opensearch-project:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
owaiskazi19 authored Mar 18, 2024
2 parents 499f6c4 + c233356 commit ab1e054
Show file tree
Hide file tree
Showing 90 changed files with 5,085 additions and 210 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ subprojects {
configurations.all {
// Force spotless depending on newer version of guava due to CVE-2023-2976. Remove after spotless upgrades.
resolutionStrategy.force "com.google.guava:guava:32.1.2-jre"
resolutionStrategy.force 'org.apache.commons:commons-compress:1.25.0'
resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0'
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ public class CommonValue {
+ AbstractConnector.CREDENTIAL_FIELD
+ "\" : {\"type\": \"flat_object\"},\n"
+ " \""
+ AbstractConnector.CLIENT_CONFIG_FIELD
+ "\" : {\"type\": \"flat_object\"},\n"
+ " \""
+ AbstractConnector.ACTIONS_FIELD
+ "\" : {\"type\": \"flat_object\"}\n";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.Set;

@Log4j2
@SuppressWarnings("removal")
public class MLCommonsClassLoader {

private static Map<Enum<?>, Class<?>> parameterClassMap = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ public abstract class AbstractConnector implements Connector {
public static final String BACKEND_ROLES_FIELD = "backend_roles";
public static final String OWNER_FIELD = "owner";
public static final String ACCESS_FIELD = "access";
public static final String CLIENT_CONFIG_FIELD = "client_config";


protected String name;
protected String description;
Expand All @@ -65,6 +67,8 @@ public abstract class AbstractConnector implements Connector {
protected AccessMode access;
protected Instant createdTime;
protected Instant lastUpdateTime;
@Setter
protected ConnectorClientConfig connectorClientConfig;

protected Map<String, String> createPredictDecryptedHeaders(Map<String, String> headers) {
if (headers == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ public class AwsConnector extends HttpConnector {
@Builder(builderMethodName = "awsConnectorBuilder")
public AwsConnector(String name, String description, String version, String protocol,
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
List<String> backendRoles, AccessMode accessMode, User owner) {
super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode, owner);
List<String> backendRoles, AccessMode accessMode, User owner,
ConnectorClientConfig connectorClientConfig) {
super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode,
owner, connectorClientConfig);
validate();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ public interface Connector extends ToXContentObject, Writeable {
Map<String, String> getParameters();

List<ConnectorAction> getActions();

ConnectorClientConfig getConnectorClientConfig();

String getPredictEndpoint(Map<String, String> parameters);

String getPredictHttpMethod();
Expand Down Expand Up @@ -107,6 +110,7 @@ static Connector createConnector(XContentBuilder builder, String connectorProtoc
}
}

@SuppressWarnings("removal")
static Connector createConnector(XContentParser parser) throws IOException {
Map<String, Object> connectorMap = parser.map();
String jsonStr;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.connector;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

@Getter
@EqualsAndHashCode
public class ConnectorClientConfig implements ToXContentObject, Writeable {

public static final String MAX_CONNECTION_FIELD = "max_connection";
public static final String CONNECTION_TIMEOUT_FIELD = "connection_timeout";
public static final String READ_TIMEOUT_FIELD = "read_timeout";

public static final Integer MAX_CONNECTION_DEFAULT_VALUE = Integer.valueOf(30);
public static final Integer CONNECTION_TIMEOUT_DEFAULT_VALUE = Integer.valueOf(30000);
public static final Integer READ_TIMEOUT_DEFAULT_VALUE = Integer.valueOf(30000);

private Integer maxConnections;
private Integer connectionTimeout;
private Integer readTimeout;

@Builder(toBuilder = true)
public ConnectorClientConfig(
Integer maxConnections,
Integer connectionTimeout,
Integer readTimeout
) {
this.maxConnections = maxConnections;
this.connectionTimeout = connectionTimeout;
this.readTimeout = readTimeout;

}

public ConnectorClientConfig(StreamInput input) throws IOException {
this.maxConnections = input.readOptionalInt();
this.connectionTimeout = input.readOptionalInt();
this.readTimeout = input.readOptionalInt();
}

public ConnectorClientConfig() {
this.maxConnections = MAX_CONNECTION_DEFAULT_VALUE;
this.connectionTimeout = CONNECTION_TIMEOUT_DEFAULT_VALUE;
this.readTimeout = READ_TIMEOUT_DEFAULT_VALUE;
}

@Override
public void writeTo(StreamOutput out) throws IOException {

out.writeOptionalInt(maxConnections);
out.writeOptionalInt(connectionTimeout);
out.writeOptionalInt(readTimeout);
}

@Override
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
XContentBuilder builder = xContentBuilder.startObject();
if (maxConnections != null) {
builder.field(MAX_CONNECTION_FIELD, maxConnections);
}
if (connectionTimeout != null) {
builder.field(CONNECTION_TIMEOUT_FIELD, connectionTimeout);
}
if (readTimeout != null) {
builder.field(READ_TIMEOUT_FIELD, readTimeout);
}
return builder.endObject();
}

public static ConnectorClientConfig fromStream(StreamInput in) throws IOException {
ConnectorClientConfig connectorClientConfig = new ConnectorClientConfig(in);
return connectorClientConfig;
}

public static ConnectorClientConfig parse(XContentParser parser) throws IOException {
Integer maxConnections = null;
Integer connectionTimeout = null;
Integer readTimeout = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case MAX_CONNECTION_FIELD:
maxConnections = parser.intValue();
break;
case CONNECTION_TIMEOUT_FIELD:
connectionTimeout = parser.intValue();
break;
case READ_TIMEOUT_FIELD:
readTimeout = parser.intValue();
break;
default:
parser.skipChildren();
break;
}
}
return ConnectorClientConfig.builder()
.maxConnections(maxConnections)
.connectionTimeout(connectionTimeout)
.readTimeout(readTimeout)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ public class HttpConnector extends AbstractConnector {
@Builder
public HttpConnector(String name, String description, String version, String protocol,
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
List<String> backendRoles, AccessMode accessMode, User owner) {
List<String> backendRoles, AccessMode accessMode, User owner,
ConnectorClientConfig connectorClientConfig) {
validateProtocol(protocol);
this.name = name;
this.description = description;
Expand All @@ -64,6 +65,8 @@ public HttpConnector(String name, String description, String version, String pro
this.backendRoles = backendRoles;
this.access = accessMode;
this.owner = owner;
this.connectorClientConfig = connectorClientConfig;

}

public HttpConnector(String protocol, XContentParser parser) throws IOException {
Expand Down Expand Up @@ -121,6 +124,9 @@ public HttpConnector(String protocol, XContentParser parser) throws IOException
case LAST_UPDATED_TIME_FIELD:
lastUpdateTime = Instant.ofEpochMilli(parser.longValue());
break;
case CLIENT_CONFIG_FIELD:
connectorClientConfig = ConnectorClientConfig.parse(parser);
break;
default:
parser.skipChildren();
break;
Expand Down Expand Up @@ -167,6 +173,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (lastUpdateTime != null) {
builder.field(LAST_UPDATED_TIME_FIELD, lastUpdateTime.toEpochMilli());
}
if (connectorClientConfig != null) {
builder.field(CLIENT_CONFIG_FIELD, connectorClientConfig);
}
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -205,6 +214,11 @@ private void parseFromStream(StreamInput input) throws IOException {
if (input.readBoolean()) {
this.owner = new User(input);
}
this.createdTime = input.readOptionalInstant();
this.lastUpdateTime = input.readOptionalInstant();
if (input.readBoolean()) {
this.connectorClientConfig = new ConnectorClientConfig(input);
}
}

@Override
Expand Down Expand Up @@ -247,6 +261,14 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalInstant(createdTime);
out.writeOptionalInstant(lastUpdateTime);
if (connectorClientConfig != null) {
out.writeBoolean(true);
connectorClientConfig.writeTo(out);
} else {
out.writeBoolean(false);
}
}

@Override
Expand Down Expand Up @@ -279,6 +301,9 @@ public void update(MLCreateConnectorInput updateContent, Function<String, String
if (updateContent.getAccess() != null) {
this.access = updateContent.getAccess();
}
if (updateContent.getConnectorClientConfig() != null) {
this.connectorClientConfig = updateContent.getConnectorClientConfig();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.OpenSearchParseException;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -49,6 +50,14 @@ public static MLRateLimiter parse(XContentParser parser) throws IOException {
switch (fieldName) {
case LIMIT_FIELD:
limit = parser.text();
try {
double limitNumber = Double.parseDouble(limit);
if (limitNumber < 0) {
throw new OpenSearchParseException("Limit field must be a positive number.");
}
} catch (NumberFormatException e) {
throw new OpenSearchParseException("Limit field must be a positive number.");
}
break;
case UNIT_FIELD:
unit = TimeUnit.valueOf(parser.text());
Expand Down
Loading

0 comments on commit ab1e054

Please sign in to comment.