From 5989acf62fbcd8acc97c54c878117a688e155c30 Mon Sep 17 00:00:00 2001
From: TrungBui59
Date: Fri, 8 Dec 2023 18:08:01 -0500
Subject: [PATCH] Question Answering model without translate output

Signed-off-by: TrungBui59
---
 .../QuestionAnswerModel.java                  |  74 +++++++++++++
 .../QuestionAnswerTranslator.java             | 110 +++++++++++++++++++
 2 files changed, 184 insertions(+)
 create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnswerModel.java
 create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnswerTranslator.java

diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnswerModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnswerModel.java
new file mode 100644
index 0000000000..8edac1e961
--- /dev/null
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnswerModel.java
@@ -0,0 +1,74 @@
+package org.opensearch.ml.engine.algorithms.question_answering;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.opensearch.ml.common.dataset.MLInputDataset;
+import org.opensearch.ml.common.dataset.QuestionAnswerInputDataSet;
+import org.opensearch.ml.common.input.MLInput;
+import org.opensearch.ml.common.model.MLModelConfig;
+import org.opensearch.ml.common.output.model.ModelTensorOutput;
+import org.opensearch.ml.common.output.model.ModelTensors;
+import org.opensearch.ml.engine.algorithms.DLModel;
+
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.translate.TranslateException;
+import ai.djl.translate.Translator;
+import ai.djl.translate.TranslatorFactory;
+
+/**
+ * Deep-learning model wrapper that answers questions against a context document.
+ */
+public class QuestionAnswerModel extends DLModel {
+
+    /**
+     * Runs one prediction per question, pairing each question with the data set's context.
+     *
+     * @param modelId id of the model; not used by this implementation
+     * @param mlInput input whose data set must be a {@link QuestionAnswerInputDataSet}
+     * @return one {@link ModelTensors} entry per question, in question order
+     * @throws TranslateException declared by the predictor's {@code predict} call
+     */
+    @Override
+    public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
+        MLInputDataset inputDataSet = mlInput.getInputDataset();
+        QuestionAnswerInputDataSet ds = (QuestionAnswerInputDataSet) inputDataSet;
+        String context = ds.getContextDocs();
+        List<ModelTensors> tensorOutputs = new ArrayList<>();
+        for (String question : ds.getQuestionsList()) {
+            // Order matters: the translator reads the context at index 0 and the question at index 1.
+            Input input = new Input();
+            input.add(context);
+            input.add(question);
+            Output output = getPredictor().predict(input);
+            tensorOutputs.add(ModelTensors.fromBytes(output.getData().getAsBytes()));
+        }
+        return new ModelTensorOutput(tensorOutputs);
+    }
+
+    /**
+     * Supplies the translator used for question answering.
+     *
+     * @param engine deep-learning engine name; not used by this implementation
+     * @param modelConfig model configuration; not used by this implementation
+     * @return a new {@link QuestionAnswerTranslator}
+     */
+    @Override
+    public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) {
+        return new QuestionAnswerTranslator();
+    }
+
+    /**
+     * Returns {@code null}; this model provides a translator directly instead of a factory.
+     *
+     * @param engine deep-learning engine name; not used by this implementation
+     * @param modelConfig model configuration; not used by this implementation
+     * @return always {@code null}
+     */
+    @Override
+    public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig) {
+        return null;
+    }
+
+}
diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnswerTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnswerTranslator.java
new file mode 100644
index 0000000000..3eb8d5a69b
--- /dev/null
+++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnswerTranslator.java
@@ -0,0 +1,110 @@
+package org.opensearch.ml.engine.algorithms.question_answering;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import org.opensearch.ml.common.output.model.MLResultDataType;
+import org.opensearch.ml.common.output.model.ModelTensor;
+import org.opensearch.ml.common.output.model.ModelTensors;
+import org.opensearch.ml.engine.algorithms.text_embedding.SentenceTransformerTextEmbeddingTranslator;
+
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.modality.Input;
+import ai.djl.modality.Output;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.DataType;
+import ai.djl.translate.Batchifier;
+import ai.djl.translate.TranslatorContext;
+
+/**
+ * Translator that tokenizes a question/context pair and serializes the raw model output tensors.
+ */
+public class QuestionAnswerTranslator extends SentenceTransformerTextEmbeddingTranslator {
+    private static final String QUESTION_ANSWER_NAME = "question_answering";
+
+    // Tokens of the most recently encoded input. NOTE(review): per-instance mutable state;
+    // confirm the translator is never shared across concurrent predictions.
+    private String[] tokens;
+
+    /**
+     * Encodes the question together with its context into the model's input tensors.
+     *
+     * @param ctx translator context supplying the {@link NDManager}
+     * @param input input holding the context document at index 0 and the question at index 1
+     * @return list of {@code input_ids}, {@code attention_mask} and {@code token_type_ids}, in that order
+     */
+    @Override
+    public NDList processInput(TranslatorContext ctx, Input input) {
+        String contextDoc = input.getAsString(0);
+        String question = input.getAsString(1);
+        NDManager manager = ctx.getNDManager();
+
+        Encoding encodings = tokenizer.encode(question, contextDoc);
+        this.tokens = encodings.getTokens();
+
+        NDArray indicesArray = manager.create(encodings.getIds());
+        indicesArray.setName("input_ids");
+
+        NDArray attentionMaskArray = manager.create(encodings.getAttentionMask());
+        attentionMaskArray.setName("attention_mask");
+
+        NDArray tokenTypeArray = manager.create(encodings.getTypeIds());
+        tokenTypeArray.setName("token_type_ids");
+
+        NDList ndList = new NDList();
+        ndList.add(indicesArray);
+        ndList.add(attentionMaskArray);
+        ndList.add(tokenTypeArray);
+        return ndList;
+    }
+
+    /**
+     * Wraps every output array in a {@link ModelTensor} and serializes them into the response.
+     * The arrays are passed through as-is; they are not decoded into an answer string here.
+     *
+     * @param ctx translator context; not used by this implementation
+     * @param list arrays produced by the model
+     * @return a 200 response whose data is the serialized {@link ModelTensors}
+     */
+    @Override
+    public Output processOutput(TranslatorContext ctx, NDList list) {
+        Output output = new Output(200, "OK");
+        List<ModelTensor> outputs = new ArrayList<>();
+        Iterator<NDArray> iterator = list.iterator();
+
+        while (iterator.hasNext()) {
+            NDArray ndArray = iterator.next();
+            DataType dataType = ndArray.getDataType();
+            ByteBuffer buffer = ndArray.toByteBuffer();
+            ModelTensor tensor = ModelTensor
+                .builder()
+                .name(QUESTION_ANSWER_NAME)
+                .data(ndArray.toArray())
+                .shape(ndArray.getShape().getShape())
+                .dataType(MLResultDataType.valueOf(dataType.name()))
+                .byteBuffer(buffer)
+                .build();
+            outputs.add(tensor);
+        }
+
+        // predict() deserializes this payload with ModelTensors.fromBytes, so the data must be set.
+        ModelTensors modelTensorOutput = new ModelTensors(outputs);
+        output.add(modelTensorOutput.toBytes());
+        return output;
+    }
+
+    /**
+     * Selects how inputs are batched.
+     *
+     * @return {@link Batchifier#STACK}
+     */
+    @Override
+    public Batchifier getBatchifier() {
+        return Batchifier.STACK;
+    }
+
+}