Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/TrungBui59/ml-commons into …
Browse files Browse the repository at this point in the history
…question_answering
  • Loading branch information
TrungBui59 committed Dec 8, 2023
2 parents a52807d + 5989acf commit 2ea3c57
Show file tree
Hide file tree
Showing 108 changed files with 6,652 additions and 376 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ buildscript {
plugins {
id 'com.netflix.nebula.ospackage' version "11.5.0"
id 'java'
id "io.freefair.lombok" version "8.0.1"
id "io.freefair.lombok" version "8.4"
id 'jacoco'
}

Expand Down
2 changes: 1 addition & 1 deletion client/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ plugins {
id 'jacoco'
id 'com.github.johnrengelman.shadow'
id 'maven-publish'
id 'com.diffplug.spotless' version '6.18.0'
id 'com.diffplug.spotless' version '6.23.0'
id 'signing'
}

Expand Down
2 changes: 1 addition & 1 deletion common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies {
}

lombok {
version = "1.18.28"
version = "1.18.30"
}

jacocoTestReport {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class CommonValue {
public static final String ML_MODEL_INDEX = ".plugins-ml-model";
public static final String ML_TASK_INDEX = ".plugins-ml-task";
public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 7;
public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 8;
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2;
Expand Down Expand Up @@ -186,6 +186,9 @@ public class CommonValue {
+ MLModel.DEPLOY_TO_ALL_NODES_FIELD
+ "\": {\"type\": \"boolean\"},\n"
+ " \""
+ MLModel.IS_HIDDEN_FIELD
+ "\": {\"type\": \"boolean\"},\n"
+ " \""
+ MLModel.MODEL_CONFIG_FIELD
+ "\" : {\"properties\":{\""
+ MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\""
Expand Down
18 changes: 13 additions & 5 deletions common/src/main/java/org/opensearch/ml/common/FunctionName.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.ml.common;

import java.util.HashSet;
import java.util.Set;

public enum FunctionName {
LINEAR_REGRESSION,
KMEANS,
Expand All @@ -17,6 +20,7 @@ public enum FunctionName {
RCF_SUMMARIZE,
LOGISTIC_REGRESSION,
TEXT_EMBEDDING,
TEXT_SIMILARITY,
SPARSE_ENCODING,
SPARSE_TOKENIZE,
METRICS_CORRELATION,
Expand All @@ -31,15 +35,19 @@ public static FunctionName from(String value) {
}
}

private static final HashSet<FunctionName> DL_MODELS = new HashSet<>(Set.of(
TEXT_EMBEDDING,
TEXT_SIMILARITY,
SPARSE_ENCODING,
SPARSE_TOKENIZE,
QUESTION_ANSWER
));

/**
* Check if model is deep learning model.
* @return true for deep learning model.
*/
public static boolean isDLModel(FunctionName functionName) {
if (functionName == TEXT_EMBEDDING || functionName == SPARSE_ENCODING ||
functionName == SPARSE_TOKENIZE || functionName == QUESTION_ANSWER) {
return true;
}
return false;
return DL_MODELS.contains(functionName);
}
}
17 changes: 17 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ public class MLModel implements ToXContentObject {
public static final String CURRENT_WORKER_NODE_COUNT_FIELD = "current_worker_node_count";
public static final String PLANNING_WORKER_NODES_FIELD = "planning_worker_nodes";
public static final String DEPLOY_TO_ALL_NODES_FIELD = "deploy_to_all_nodes";

public static final String IS_HIDDEN_FIELD = "is_hidden";
public static final String CONNECTOR_FIELD = "connector";
public static final String CONNECTOR_ID_FIELD = "connector_id";

Expand Down Expand Up @@ -110,6 +112,9 @@ public class MLModel implements ToXContentObject {
private String[] planningWorkerNodes; // plan to deploy model to these nodes
private boolean deployToAllNodes;

//is domain manager creates any special hidden model in the cluster this status will be true. Otherwise,
// False by default
private Boolean isHidden;
@Setter
private Connector connector;
private String connectorId;
Expand Down Expand Up @@ -139,6 +144,7 @@ public MLModel(String name,
Integer currentWorkerNodeCount,
String[] planningWorkerNodes,
boolean deployToAllNodes,
Boolean isHidden,
Connector connector,
String connectorId) {
this.name = name;
Expand Down Expand Up @@ -166,6 +172,7 @@ public MLModel(String name,
this.currentWorkerNodeCount = currentWorkerNodeCount;
this.planningWorkerNodes = planningWorkerNodes;
this.deployToAllNodes = deployToAllNodes;
this.isHidden = isHidden;
this.connector = connector;
this.connectorId = connectorId;
}
Expand Down Expand Up @@ -210,6 +217,7 @@ public MLModel(StreamInput input) throws IOException{
currentWorkerNodeCount = input.readOptionalInt();
planningWorkerNodes = input.readOptionalStringArray();
deployToAllNodes = input.readBoolean();
isHidden = input.readOptionalBoolean();
modelGroupId = input.readOptionalString();
if (input.readBoolean()) {
connector = Connector.fromStream(input);
Expand Down Expand Up @@ -263,6 +271,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalInt(currentWorkerNodeCount);
out.writeOptionalStringArray(planningWorkerNodes);
out.writeBoolean(deployToAllNodes);
out.writeOptionalBoolean(isHidden);
out.writeOptionalString(modelGroupId);
if (connector != null) {
out.writeBoolean(true);
Expand Down Expand Up @@ -351,6 +360,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (deployToAllNodes) {
builder.field(DEPLOY_TO_ALL_NODES_FIELD, deployToAllNodes);
}
if (isHidden != null) {
builder.field(MLModel.IS_HIDDEN_FIELD, isHidden);
}
if (connector != null) {
builder.field(CONNECTOR_FIELD, connector);
}
Expand Down Expand Up @@ -393,6 +405,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
Integer currentWorkerNodeCount = null;
List<String> planningWorkerNodes = new ArrayList<>();
boolean deployToAllNodes = false;
boolean isHidden = false;
Connector connector = null;
String connectorId = null;

Expand Down Expand Up @@ -476,6 +489,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case DEPLOY_TO_ALL_NODES_FIELD:
deployToAllNodes = parser.booleanValue();
break;
case IS_HIDDEN_FIELD:
isHidden = parser.booleanValue();
break;
case CONNECTOR_FIELD:
connector = createConnector(parser);
break;
Expand Down Expand Up @@ -537,6 +553,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
.currentWorkerNodeCount(currentWorkerNodeCount)
.planningWorkerNodes(planningWorkerNodes.toArray(new String[0]))
.deployToAllNodes(deployToAllNodes)
.isHidden(isHidden)
.connector(connector)
.connectorId(connectorId)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,25 @@ public class ActionConstants {
/** name of success field in all requests */
public final static String SUCCESS_FIELD = "success";

private final static String BASE_REST_PATH = "/_plugins/_ml/memory/conversation";
/** path for create conversation */
public final static String CREATE_CONVERSATION_REST_PATH = "/_plugins/_ml/memory/conversation";
/** path for list conversations */
public final static String GET_CONVERSATIONS_REST_PATH = "/_plugins/_ml/memory/conversation";
/** path for put interaction */
public final static String CREATE_INTERACTION_REST_PATH = "/_plugins/_ml/memory/conversation/{conversation_id}";
public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/_create";
/** path for get conversations */
public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_list";
/** path for create interaction */
public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_create";
/** path for get interactions */
public final static String GET_INTERACTIONS_REST_PATH = "/_plugins/_ml/memory/conversation/{conversation_id}";
public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_list";
/** path for delete conversation */
public final static String DELETE_CONVERSATION_REST_PATH = "/_plugins/_ml/memory/conversation/{conversation_id}";
public final static String DELETE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_delete";
/** path for search conversations */
public final static String SEARCH_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_search";
/** path for search interactions */
public final static String SEARCH_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_search";
/** path for get conversation */
public final static String GET_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}";
/** path for get interaction */
public final static String GET_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/{interaction_id}";

/** default max results returned by get operations */
public final static int DEFAULT_MAX_RESULTS = 10;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ public enum MLInputDataType {
DATA_FRAME,
TEXT_DOCS,
REMOTE,
QUESTION_ANSWER
QUESTION_ANSWER,
TEXT_SIMILARITY,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.common.dataset;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.annotation.InputDataSet;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.FieldDefaults;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@InputDataSet(MLInputDataType.TEXT_SIMILARITY)
public class TextSimilarityInputDataSet extends MLInputDataset {

List<String> textDocs;

String queryText;

@Builder(toBuilder = true)
public TextSimilarityInputDataSet(String queryText, List<String> textDocs) {
super(MLInputDataType.TEXT_SIMILARITY);
Objects.requireNonNull(textDocs);
Objects.requireNonNull(queryText);
if(textDocs.isEmpty()) {
throw new IllegalArgumentException("No text documents were provided");
}
this.textDocs = textDocs;
this.queryText = queryText;
}

public TextSimilarityInputDataSet(StreamInput in) throws IOException {
super(MLInputDataType.TEXT_SIMILARITY);
this.queryText = in.readString();
int size = in.readInt();
this.textDocs = new ArrayList<String>();
for(int i = 0; i < size; i++) {
String context = in.readString();
this.textDocs.add(context);
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(queryText);
out.writeInt(this.textDocs.size());
for (String doc : this.textDocs) {
out.writeString(doc);
}
}
}
25 changes: 25 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/input/MLInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
Expand All @@ -21,6 +22,7 @@
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.search.builder.SearchSourceBuilder;

Expand Down Expand Up @@ -55,6 +57,8 @@ public class MLInput implements Input {
public static final String TARGET_RESPONSE_POSITIONS_FIELD = "target_response_positions";
// Input text sentences for text embedding model
public static final String TEXT_DOCS_FIELD = "text_docs";
// Input query text to compare against for text similarity model
public static final String QUERY_TEXT_FIELD = "query_text";

// Input context docs for question answering model
public static final String CONTEXT_DOCS = "context_docs";
Expand Down Expand Up @@ -162,6 +166,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(TARGET_RESPONSE_POSITIONS_FIELD, targetPositions.toArray(new Integer[0]));
}
}
break;
case TEXT_SIMILARITY:
TextSimilarityInputDataSet ds = (TextSimilarityInputDataSet) this.inputDataset;
List<String> tdocs = ds.getTextDocs();
String queryText = ds.getQueryText();
builder.field(QUERY_TEXT_FIELD, queryText);
if (tdocs != null && !tdocs.isEmpty()) {
builder.startArray(TEXT_DOCS_FIELD);
for(String d : tdocs) {
builder.value(d);
}
builder.endArray();
}
break;
default:
break;
}
Expand Down Expand Up @@ -191,6 +209,7 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
List<String> targetResponse = new ArrayList<>();
List<Integer> targetResponsePositions = new ArrayList<>();
List<String> textDocs = new ArrayList<>();
String queryText = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -238,6 +257,9 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
textDocs.add(parser.text());
}
break;
case QUERY_TEXT_FIELD:
queryText = parser.text();
break;
default:
parser.skipChildren();
break;
Expand All @@ -248,6 +270,9 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions);
inputDataSet = new TextDocsInputDataSet(textDocs, filter);
}
if (algorithm == FunctionName.TEXT_SIMILARITY) {
inputDataSet = new TextSimilarityInputDataSet(queryText, textDocs);
}
return new MLInput(algorithm, mlParameters, searchSourceBuilder, sourceIndices, dataFrame, inputDataSet);
}

Expand Down
Loading

0 comments on commit 2ea3c57

Please sign in to comment.