Skip to content

Commit

Permalink
support question answering model
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Mar 16, 2024
1 parent c233356 commit 0f3fe61
Show file tree
Hide file tree
Showing 10 changed files with 492 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public enum FunctionName {
SPARSE_ENCODING,
SPARSE_TOKENIZE,
TEXT_SIMILARITY,
QUESTION_ANSWERING,
AGENT;

public static FunctionName from(String value) {
Expand All @@ -42,7 +43,8 @@ public static FunctionName from(String value) {
TEXT_EMBEDDING,
TEXT_SIMILARITY,
SPARSE_ENCODING,
SPARSE_TOKENIZE
SPARSE_TOKENIZE,
QUESTION_ANSWERING
));

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
}
}
MLInputDataset inputDataSet = null;
if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.SPARSE_ENCODING || algorithm == FunctionName.SPARSE_TOKENIZE) {
if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.SPARSE_ENCODING || algorithm == FunctionName.SPARSE_TOKENIZE || algorithm == FunctionName.QUESTION_ANSWERING) {
ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions);
inputDataSet = new TextDocsInputDataSet(textDocs, filter);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
* ML input class which supports a list of text docs.
* This class can be used for TEXT_EMBEDDING model.
*/
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING, FunctionName.SPARSE_ENCODING, FunctionName.SPARSE_TOKENIZE})
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_EMBEDDING, FunctionName.SPARSE_ENCODING, FunctionName.SPARSE_TOKENIZE, FunctionName.QUESTION_ANSWERING})
public class TextDocsMLInput extends MLInput {
public static final String TEXT_DOCS_FIELD = "text_docs";
public static final String RESULT_FILTER_FIELD = "result_filter";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ public ModelTensor(String name, Number[] data, long[] shape, MLResultDataType da
this.dataAsMap = dataAsMap;
}

/**
 * Creates a tensor that carries a plain string result.
 * Only {@code name} and {@code result} are assigned by this constructor.
 *
 * @param name   name of the tensor
 * @param result string result held by the tensor
 */
public ModelTensor(String name, String result) {
    this.result = result;
    this.name = name;
}

Check warning on line 69 in common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java#L66-L69

Added lines #L66 - L69 were not covered by tests

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,9 @@ private void testClassLoader_MLInput_DlModel(FunctionName functionName) throws I
/**
 * Runs the shared DL-model class-loader check once for each function name
 * listed below.
 */
@Test
public void testClassLoader_MLInput() throws IOException {
    // Each function name is exercised exactly once; the duplicated
    // SPARSE_TOKENIZE invocation was redundant and has been dropped.
    testClassLoader_MLInput_DlModel(FunctionName.TEXT_EMBEDDING);
    testClassLoader_MLInput_DlModel(FunctionName.QUESTION_ANSWERING);
    testClassLoader_MLInput_DlModel(FunctionName.SPARSE_ENCODING);
    testClassLoader_MLInput_DlModel(FunctionName.SPARSE_TOKENIZE);
}

@Test(expected = IllegalArgumentException.class)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.algorithms.question_answering;

import static org.opensearch.ml.engine.ModelHelper.*;

import java.util.ArrayList;
import java.util.List;

import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.DLModel;
import org.opensearch.ml.engine.annotation.Function;

import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import lombok.extern.log4j.Log4j2;

/**
 * Question answering model. A request carries two text docs: the question at
 * index 0 and the context at index 1. Both are forwarded to the predictor as a
 * single {@link Input}, and the parsed predictor output is returned as one
 * {@link ModelTensors} entry.
 */
@Log4j2
@Function(FunctionName.QUESTION_ANSWERING)
public class QuestionAnsweringModel extends DLModel {

    private static final String WARM_UP_QUESTION = "How is the weather?";
    private static final String WARM_UP_CONTEXT = "The weather is nice, it is beautiful day.";

    /**
     * Runs one prediction on a fixed question/context pair.
     *
     * @param predictor   predictor to warm up
     * @param modelId     id of the model (not used here)
     * @param modelConfig model configuration (not used here)
     * @throws TranslateException if the warm-up prediction fails
     */
    @Override
    public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {
        // First request takes longer time. Predict once to warm up model.
        predictor.predict(buildInput(WARM_UP_QUESTION, WARM_UP_CONTEXT));
    }

    /**
     * Answers a question against a context.
     *
     * @param modelId id of the model (not used here)
     * @param mlInput input whose dataset must be a {@link TextDocsInputDataSet}
     *                holding the question at index 0 and the context at index 1
     * @return output containing a single {@link ModelTensors} entry
     * @throws TranslateException       if the prediction fails
     * @throws IllegalArgumentException if the dataset is not a text docs dataset
     *                                  or does not provide both a question and a context
     */
    @Override
    public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
        MLInputDataset inputDataSet = mlInput.getInputDataset();
        // Fail with a clear message instead of a ClassCastException.
        if (!(inputDataSet instanceof TextDocsInputDataSet)) {
            throw new IllegalArgumentException("Question answering model requires a text docs input dataset");
        }
        TextDocsInputDataSet textDocsInput = (TextDocsInputDataSet) inputDataSet;

        // Fail with a clear message instead of an IndexOutOfBoundsException.
        List<String> docs = textDocsInput.getDocs();
        if (docs == null || docs.size() < 2) {
            throw new IllegalArgumentException(
                "Question answering model requires two text docs: the question followed by the context"
            );
        }
        String question = docs.get(0);
        String context = docs.get(1);
        if (question == null || context == null) {
            throw new IllegalArgumentException("Question and context must not be null");
        }

        ModelResultFilter resultFilter = textDocsInput.getResultFilter();
        Output output = getPredictor().predict(buildInput(question, context));

        List<ModelTensors> tensorOutputs = new ArrayList<>();
        tensorOutputs.add(parseModelTensorOutput(output, resultFilter));
        return new ModelTensorOutput(tensorOutputs);
    }

    /**
     * @return a new {@link QuestionAnsweringTranslator}; {@code engine} and
     *         {@code modelConfig} are not consulted
     */
    @Override
    public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) throws IllegalArgumentException {
        return new QuestionAnsweringTranslator();
    }

    /**
     * @return always {@code null}; this model supplies its translator through
     *         {@link #getTranslator(String, MLModelConfig)}
     */
    @Override
    public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig) {
        return null;
    }

    /** Packs the question and then the context into an {@link Input}, in that order. */
    private static Input buildInput(String question, String context) {
        Input input = new Input();
        input.add(question);
        input.add(context);
        return input;
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.algorithms.question_answering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.SentenceTransformerTranslator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.TranslatorContext;

/**
 * Translator for question answering. It encodes a question/paragraph pair into
 * {@code input_ids} and {@code attention_mask} arrays, and decodes the answer
 * span selected by the arg-max of the two output arrays.
 *
 * <p>The encoding produced for a request is kept in the {@link TranslatorContext}
 * attachment rather than in an instance field, so no per-request state is held
 * on the translator itself.
 */
public class QuestionAnsweringTranslator extends SentenceTransformerTranslator {

    /** Context attachment key under which the request's {@link Encoding} is stored. */
    private static final String ENCODING_ATTACHMENT = "encoding";

    /**
     * Tokenizes the question (input index 0) together with the paragraph
     * (input index 1) and builds the model input arrays.
     *
     * @param ctx   translator context; receives the encoding as an attachment
     * @param input input carrying the question and the paragraph as strings
     * @return list containing the {@code input_ids} and {@code attention_mask} arrays
     */
    @Override
    public NDList processInput(TranslatorContext ctx, Input input) {
        NDManager manager = ctx.getNDManager();
        String question = input.getAsString(0);
        String paragraph = input.getAsString(1);

        Encoding encoding = tokenizer.encode(question, paragraph);
        // Remember the encoding for processOutput, which needs the tokens to rebuild the answer.
        ctx.setAttachment(ENCODING_ATTACHMENT, encoding);

        NDArray indicesArray = manager.create(encoding.getIds());
        indicesArray.setName("input_ids");

        NDArray attentionMaskArray = manager.create(encoding.getAttentionMask());
        attentionMaskArray.setName("attention_mask");

        NDList ndList = new NDList();
        ndList.add(indicesArray);
        ndList.add(attentionMaskArray);
        return ndList;
    }

    /**
     * Builds the answer from the tokens between the arg-max positions of the
     * first and second output arrays (inclusive) and wraps it in a
     * {@link ModelTensor} named {@code answer}.
     *
     * @param ctx  translator context holding the encoding stored by {@code processInput}
     * @param list model output; element 0 is used as start logits, element 1 as end logits
     * @return a 200/OK output whose payload is the serialized {@link ModelTensors}
     * @throws IllegalStateException if no encoding is attached to the context
     */
    @Override
    public Output processOutput(TranslatorContext ctx, NDList list) {
        Object attachment = ctx.getAttachment(ENCODING_ATTACHMENT);
        if (!(attachment instanceof Encoding)) {
            throw new IllegalStateException("No encoding attached to the translator context");
        }
        List<String> tokens = Arrays.asList(((Encoding) attachment).getTokens());

        NDArray startLogits = list.get(0);
        NDArray endLogits = list.get(1);
        int first = (int) startLogits.argMax().getLong();
        int second = (int) endLogits.argMax().getLong();
        // The two positions may come out reversed; always slice from the smaller to the larger.
        int startIdx = Math.min(first, second);
        int endIdx = Math.max(first, second);
        String answer = tokenizer.buildSentence(tokens.subList(startIdx, endIdx + 1));

        List<ModelTensor> outputs = new ArrayList<>();
        outputs.add(new ModelTensor("answer", answer));

        Output output = new Output(200, "OK");
        ModelTensors modelTensorOutput = new ModelTensors(outputs);
        output.add(modelTensorOutput.toBytes());
        return output;
    }

}
Loading

0 comments on commit 0f3fe61

Please sign in to comment.