Skip to content

Commit

Permalink
Finalizing input for question answering model
Browse files Browse the repository at this point in the history
Signed-off-by: TrungBui59 <[email protected]>
  • Loading branch information
TrungBui59 committed Nov 29, 2023
1 parent d1498e8 commit a52807d
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ public class MLInput implements Input {
// Input text sentences for text embedding model
public static final String TEXT_DOCS_FIELD = "text_docs";

// Input context docs for question answering model
public static final String CONTEXT_DOCS = "context_docs";
// Input list of questions for question answering model
public static final String QUESTIONS_LIST = "questions_list";

// Algorithm name
protected FunctionName algorithm;
// ML algorithm parameters
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package org.opensearch.ml.common.input.nlp;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.QuestionAnswerInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;


@org.opensearch.ml.common.annotation.MLInput(functionNames=FunctionName.QUESTION_ANSWER)
public class QuestionAnswerMLInput extends MLInput {
Expand All @@ -22,4 +29,63 @@ public QuestionAnswerMLInput(StreamInput input) throws IOException {
public void writeTo(StreamOutput output) throws IOException {
super.writeTo(output);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ALGORITHM_FIELD, algorithm.name());
if(parameters != null) {
builder.field(ML_PARAMETERS_FIELD, parameters);
}
if(inputDataset != null) {
QuestionAnswerInputDataSet ds = (QuestionAnswerInputDataSet) this.inputDataset;
List<String> questionsList = ds.getQuestionsList();
String contextDoc = ds.getContextDocs();
builder.field(CONTEXT_DOCS, contextDoc);
if (questionsList != null && !questionsList.isEmpty()) {
builder.startArray(QUESTIONS_LIST);
for(String d : questionsList) {
builder.value(d);
}
builder.endArray();
}
}
builder.endObject();
return builder;
}

public QuestionAnswerMLInput(XContentParser parser, FunctionName functionName) throws IOException {
super();
this.algorithm = functionName;
List<String> questionsList = new ArrayList<>();
String contextDoc = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case QUESTIONS_LIST:
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
String context = parser.text();
questionsList.add(context);
}
break;
case CONTEXT_DOCS:
contextDoc = parser.text();
default:
parser.skipChildren();
break;
}
}
if(questionsList.isEmpty()) {
throw new IllegalArgumentException("No question list");
}
if(contextDoc == null) {
throw new IllegalArgumentException("No context documents");
}
inputDataset = new QuestionAnswerInputDataSet(contextDoc, questionsList);
}
}

0 comments on commit a52807d

Please sign in to comment.