diff --git a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java
index 574f13e9c3..cb79328427 100644
--- a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java
+++ b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java
@@ -56,6 +56,11 @@ public class MLInput implements Input {
     // Input text sentences for text embedding model
     public static final String TEXT_DOCS_FIELD = "text_docs";
 
+    // Input context docs for question answering model
+    public static final String CONTEXT_DOCS = "context_docs";
+    // Input list of questions for question answering model
+    public static final String QUESTIONS_LIST = "questions_list";
+
     // Algorithm name
     protected FunctionName algorithm;
     // ML algorithm parameters
diff --git a/common/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnswerMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnswerMLInput.java
index d9c553855f..493116cb8b 100644
--- a/common/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnswerMLInput.java
+++ b/common/src/main/java/org/opensearch/ml/common/input/nlp/QuestionAnswerMLInput.java
@@ -1,12 +1,19 @@
 package org.opensearch.ml.common.input.nlp;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.dataset.MLInputDataset;
+import org.opensearch.ml.common.dataset.QuestionAnswerInputDataSet;
 import org.opensearch.ml.common.input.MLInput;
 
+import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
+
 @org.opensearch.ml.common.annotation.MLInput(functionNames=FunctionName.QUESTION_ANSWER)
 public class QuestionAnswerMLInput extends MLInput {
@@ -22,4 +29,88 @@ public QuestionAnswerMLInput(StreamInput input) throws IOException {
     public void writeTo(StreamOutput output) throws IOException {
         super.writeTo(output);
     }
+
+    /**
+     * Writes this input as a single XContent object.
+     *
+     * <p>The object always carries the algorithm name. The ML parameters are written only when
+     * they are set; the context document and the questions are written only when an input dataset
+     * is set, and the questions array is omitted when the list is null or empty.
+     *
+     * @param builder the builder to write into
+     * @param params serialization parameters; not consulted by this implementation
+     * @return the same {@code builder} that was passed in
+     * @throws IOException if writing to the builder fails
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(ALGORITHM_FIELD, algorithm.name());
+        if (parameters != null) {
+            builder.field(ML_PARAMETERS_FIELD, parameters);
+        }
+        if (inputDataset != null) {
+            // The cast assumes a QuestionAnswerInputDataSet was supplied -- TODO confirm with callers.
+            QuestionAnswerInputDataSet ds = (QuestionAnswerInputDataSet) this.inputDataset;
+            List<String> questionsList = ds.getQuestionsList();
+            String contextDoc = ds.getContextDocs();
+            builder.field(CONTEXT_DOCS, contextDoc);
+            if (questionsList != null && !questionsList.isEmpty()) {
+                builder.startArray(QUESTIONS_LIST);
+                for (String question : questionsList) {
+                    builder.value(question);
+                }
+                builder.endArray();
+            }
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    /**
+     * Builds a question answering input from its XContent representation.
+     *
+     * <p>The parser must be positioned on the object's {@code START_OBJECT} token. The
+     * {@code questions_list} array and the {@code context_docs} value are read; every other field
+     * is skipped.
+     *
+     * @param parser parser positioned at the start of the input object
+     * @param functionName algorithm to record on this input
+     * @throws IOException if the parser fails to read the content
+     * @throws IllegalArgumentException if no question or no context document was supplied
+     */
+    public QuestionAnswerMLInput(XContentParser parser, FunctionName functionName) throws IOException {
+        super();
+        this.algorithm = functionName;
+        List<String> questionsList = new ArrayList<>();
+        String contextDoc = null;
+
+        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
+        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
+            String fieldName = parser.currentName();
+            parser.nextToken();
+
+            switch (fieldName) {
+                case QUESTIONS_LIST:
+                    ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
+                    while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
+                        questionsList.add(parser.text());
+                    }
+                    break;
+                case CONTEXT_DOCS:
+                    contextDoc = parser.text();
+                    break;
+                default:
+                    parser.skipChildren();
+                    break;
+            }
+        }
+        if (questionsList.isEmpty()) {
+            throw new IllegalArgumentException("No question list");
+        }
+        if (contextDoc == null) {
+            throw new IllegalArgumentException("No context documents");
+        }
+        inputDataset = new QuestionAnswerInputDataSet(contextDoc, questionsList);
+    }
 }