Skip to content

Commit

Permalink
Merge branch 'opensearch-project:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
owaiskazi19 authored Mar 27, 2024
2 parents ab1e054 + f67eab0 commit 7655c5e
Show file tree
Hide file tree
Showing 119 changed files with 5,911 additions and 352 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -1 +1 @@
* @b4sjoo @dhrubo-os @jngz-es @model-collapse @rbhavna @wujunshen @ylwu-amzn @zane-neo @Zhangxunmt @austintlee @HenryL27 @samuel-oci
* @b4sjoo @dhrubo-os @jngz-es @model-collapse @rbhavna @ylwu-amzn @zane-neo @Zhangxunmt @austintlee @HenryL27 @samuel-oci @xinyual
3 changes: 2 additions & 1 deletion MAINTAINERS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ This document contains a list of maintainers in this repo. See [opensearch-proje
| Jing Zhang | [jngz-es](https://github.com/jngz-es) | Amazon |
| Junshen Wu | [wujunshen](https://github.com/wujunshen) | Amazon |
| Sicheng Song | [b4sjoo](https://github.com/b4sjoo) | Amazon |
| Xinyuan Lu | [xinyual](https://github.com/xinyual) | Amazon |
| Xun Zhang | [Zhangxunmt](https://github.com/Zhangxunmt) | Amazon |
| Yaliang Wu | [ylwu-amzn](https://github.com/ylwu-amzn) | Amazon |
| Zan Niu | [zane-neo](https://github.com/zane-neo) | Amazon |
| Austin Lee | [austintlee](https://github.com/austintlee) | Aryn |
| Henry Lindeman | [HenryL27](https://github.com/HenryL27) | Aryn |
| Samuel Herman | [samuel-oci](https://github.com/samuel-oci/) | Oracle |
| Samuel Herman | [samuel-oci](https://github.com/samuel-oci/) | Oracle |

## Emeritus

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLAgentType;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.ToolMetadata;
Expand Down Expand Up @@ -484,7 +485,7 @@ public void deleteConnector() {

@Test
public void testRegisterAgent() {
MLAgent mlAgent = MLAgent.builder().name("Agent name").build();
MLAgent mlAgent = MLAgent.builder().name("Agent name").type(MLAgentType.FLOW.name()).build();
assertEquals(registerAgentResponse, machineLearningClient.registerAgent(mlAgent).actionGet());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLAgentType;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
Expand Down Expand Up @@ -869,7 +870,7 @@ public void testRegisterAgent() {
}).when(client).execute(eq(MLRegisterAgentAction.INSTANCE), any(), any());

ArgumentCaptor<MLRegisterAgentResponse> argumentCaptor = ArgumentCaptor.forClass(MLRegisterAgentResponse.class);
MLAgent mlAgent = MLAgent.builder().name("Agent name").build();
MLAgent mlAgent = MLAgent.builder().name("Agent name").type(MLAgentType.FLOW.name()).build();

machineLearningNodeClient.registerAgent(mlAgent, registerAgentResponseActionListener);

Expand Down
49 changes: 46 additions & 3 deletions common/src/main/java/org/opensearch/ml/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,17 @@ public class CommonValue {
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
public static final String ML_TASK_INDEX = ".plugins-ml-task";
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 9;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 10;
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 3;
public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2;
public static final String ML_CONTROLLER_INDEX = ".plugins-ml-controller";
public static final Integer ML_CONTROLLER_INDEX_SCHEMA_VERSION = 1;
public static final String ML_MAP_RESPONSE_KEY = "response";
public static final String ML_AGENT_INDEX = ".plugins-ml-agent";
public static final Integer ML_AGENT_INDEX_SCHEMA_VERSION = 1;
public static final Integer ML_AGENT_INDEX_SCHEMA_VERSION = 2;
public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta";
public static final Integer ML_MEMORY_META_INDEX_SCHEMA_VERSION = 1;
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
Expand Down Expand Up @@ -261,6 +261,46 @@ public class CommonValue {
+ " \""
+ MLModel.LAST_UNDEPLOYED_TIME_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
+ MLModel.GUARDRAILS_FIELD
+ "\" : {\n" +
" \"properties\": {\n" +
" \"input_guardrail\": {\n" +
" \"properties\": {\n" +
" \"regex\": {\n" +
" \"type\": \"text\"\n" +
" },\n" +
" \"stop_words\": {\n" +
" \"properties\": {\n" +
" \"index_name\": {\n" +
" \"type\": \"text\"\n" +
" },\n" +
" \"source_fields\": {\n" +
" \"type\": \"text\"\n" +
" }\n" +
" }\n" +
" }\n" +
" }\n" +
" },\n" +
" \"output_guardrail\": {\n" +
" \"properties\": {\n" +
" \"regex\": {\n" +
" \"type\": \"text\"\n" +
" },\n" +
" \"stop_words\": {\n" +
" \"properties\": {\n" +
" \"index_name\": {\n" +
" \"type\": \"text\"\n" +
" },\n" +
" \"source_fields\": {\n" +
" \"type\": \"text\"\n" +
" }\n" +
" }\n" +
" }\n" +
" }\n" +
" }\n" +
" }\n" +
" },\n"
+ " \""
+ MLModel.CONNECTOR_FIELD
+ "\": {" + ML_CONNECTOR_INDEX_FIELDS + " }\n},"
Expand Down Expand Up @@ -416,6 +456,9 @@ public class CommonValue {
+ MLAgent.MEMORY_FIELD
+ "\" : {\"type\": \"flat_object\"},\n"
+ " \""
+ MLAgent.IS_HIDDEN_FIELD
+ "\": {\"type\": \"boolean\"},\n"
+ " \""
+ MLAgent.CREATED_TIME_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n"
+ " \""
Expand Down
15 changes: 13 additions & 2 deletions common/src/main/java/org/opensearch/ml/common/FunctionName.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common;

import java.util.HashSet;
import java.util.Locale;
import java.util.Set;

// Please strictly add new FunctionName to the last line
Expand All @@ -28,11 +29,12 @@ public enum FunctionName {
SPARSE_ENCODING,
SPARSE_TOKENIZE,
TEXT_SIMILARITY,
QUESTION_ANSWERING,
AGENT;

public static FunctionName from(String value) {
try {
return FunctionName.valueOf(value);
return FunctionName.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException("Wrong function name");
}
Expand All @@ -42,7 +44,8 @@ public static FunctionName from(String value) {
TEXT_EMBEDDING,
TEXT_SIMILARITY,
SPARSE_ENCODING,
SPARSE_TOKENIZE
SPARSE_TOKENIZE,
QUESTION_ANSWERING
));

/**
Expand All @@ -52,4 +55,12 @@ public static FunctionName from(String value) {
public static boolean isDLModel(FunctionName functionName) {
return DL_MODELS.contains(functionName);
}

public static boolean needDeployFirst(FunctionName functionName) {
return DL_MODELS.contains(functionName) || functionName == REMOTE;
}

public static boolean isAutoDeployEnabled(boolean autoDeploymentEnabled, FunctionName functionName) {
return autoDeploymentEnabled && functionName == FunctionName.REMOTE;
}
}
25 changes: 25 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/MLAgentType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common;

import java.util.Locale;

public enum MLAgentType {
FLOW,
CONVERSATIONAL,
CONVERSATIONAL_FLOW;

public static MLAgentType from(String value) {
if (value == null) {
throw new IllegalArgumentException("Agent type can't be null");
}
try {
return MLAgentType.valueOf(value.toUpperCase(Locale.ROOT));
} catch (Exception e) {
throw new IllegalArgumentException("Wrong Agent type");
}
}
}
30 changes: 28 additions & 2 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.model.Guardrails;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.controller.MLRateLimiter;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.model.QuestionAnsweringModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.model.MetricsCorrelationModelConfig;

Expand Down Expand Up @@ -84,6 +86,7 @@ public class MLModel implements ToXContentObject {
public static final String IS_HIDDEN_FIELD = "is_hidden";
public static final String CONNECTOR_FIELD = "connector";
public static final String CONNECTOR_ID_FIELD = "connector_id";
public static final String GUARDRAILS_FIELD = "guardrails";

private String name;
private String modelGroupId;
Expand Down Expand Up @@ -116,7 +119,6 @@ public class MLModel implements ToXContentObject {
private Integer totalChunks; // model chunk doc only
private Integer planningWorkerNodeCount; // plan to deploy model to how many nodes
private Integer currentWorkerNodeCount; // model is deployed to how many nodes

private String[] planningWorkerNodes; // plan to deploy model to these nodes
private boolean deployToAllNodes;

Expand All @@ -127,6 +129,7 @@ public class MLModel implements ToXContentObject {
@Setter
private Connector connector;
private String connectorId;
private Guardrails guardrails;

@Builder(toBuilder = true)
public MLModel(String name,
Expand Down Expand Up @@ -158,7 +161,8 @@ public MLModel(String name,
boolean deployToAllNodes,
Boolean isHidden,
Connector connector,
String connectorId) {
String connectorId,
Guardrails guardrails) {
this.name = name;
this.modelGroupId = modelGroupId;
this.algorithm = algorithm;
Expand Down Expand Up @@ -190,6 +194,7 @@ public MLModel(String name,
this.isHidden = isHidden;
this.connector = connector;
this.connectorId = connectorId;
this.guardrails = guardrails;
}

public MLModel(StreamInput input) throws IOException {
Expand All @@ -215,6 +220,8 @@ public MLModel(StreamInput input) throws IOException {
if (input.readBoolean()) {
if (algorithm.equals(FunctionName.METRICS_CORRELATION)) {
modelConfig = new MetricsCorrelationModelConfig(input);
} else if (algorithm.equals(FunctionName.QUESTION_ANSWERING)) {
modelConfig = new QuestionAnsweringModelConfig(input);
} else {
modelConfig = new TextEmbeddingModelConfig(input);
}
Expand Down Expand Up @@ -243,6 +250,9 @@ public MLModel(StreamInput input) throws IOException {
connector = Connector.fromStream(input);
}
connectorId = input.readOptionalString();
if (input.readBoolean()) {
this.guardrails = new Guardrails(input);
}
}
}

Expand Down Expand Up @@ -308,6 +318,12 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalString(connectorId);
if (guardrails != null) {
out.writeBoolean(true);
guardrails.writeTo(out);
} else {
out.writeBoolean(false);
}
}

@Override
Expand Down Expand Up @@ -406,6 +422,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (connectorId != null) {
builder.field(CONNECTOR_ID_FIELD, connectorId);
}
if (guardrails != null) {
builder.field(GUARDRAILS_FIELD, guardrails);
}
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -448,6 +467,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
boolean isHidden = false;
Connector connector = null;
String connectorId = null;
Guardrails guardrails = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -510,6 +530,8 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case MODEL_CONFIG_FIELD:
if (FunctionName.METRICS_CORRELATION.name().equals(algorithmName)) {
modelConfig = MetricsCorrelationModelConfig.parse(parser);
} else if (FunctionName.QUESTION_ANSWERING.name().equals(algorithmName)) {
modelConfig = QuestionAnsweringModelConfig.parse(parser);
} else {
modelConfig = TextEmbeddingModelConfig.parse(parser);
}
Expand Down Expand Up @@ -571,6 +593,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case LAST_UNDEPLOYED_TIME_FIELD:
lastUndeployedTime = Instant.ofEpochMilli(parser.longValue());
break;
case GUARDRAILS_FIELD:
guardrails = Guardrails.parse(parser);
break;
default:
parser.skipChildren();
break;
Expand Down Expand Up @@ -608,6 +633,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
.isHidden(isHidden)
.connector(connector)
.connectorId(connectorId)
.guardrails(guardrails)
.build();
}

Expand Down
Loading

0 comments on commit 7655c5e

Please sign in to comment.