Skip to content

Commit

Permalink
Question Answering model without translate output
Browse files Browse the repository at this point in the history
Signed-off-by: TrungBui59 <[email protected]>
  • Loading branch information
TrungBui59 committed Dec 8, 2023
1 parent 2761d7d commit 5989acf
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package org.opensearch.ml.engine.algorithms.question_answering;

import java.util.ArrayList;
import java.util.List;

import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.QuestionAnswerInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.DLModel;

import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;

public class QuestionAnswerModel extends DLModel {

@Override
public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
MLInputDataset inputDataSet = mlInput.getInputDataset();
List<ModelTensors> tensorOutputs = new ArrayList<>();
Output output;
QuestionAnswerInputDataSet ds = (QuestionAnswerInputDataSet) inputDataSet;
String context = ds.getContextDocs();
for (String question : ds.getQuestionsList()) {
Input input = new Input();
input.add(context);
input.add(question);
output = getPredictor().predict(input);
ModelTensors outputTensors = ModelTensors.fromBytes(output.getData().getAsBytes());
tensorOutputs.add(outputTensors);
}
return new ModelTensorOutput(tensorOutputs);
}

@Override
public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) {
return new QuestionAnswerTranslator();
}

@Override
public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig) {
return null;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package org.opensearch.ml.engine.algorithms.question_answering;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.engine.algorithms.text_embedding.SentenceTransformerTextEmbeddingTranslator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslatorContext;

public class QuestionAnswerTranslator extends SentenceTransformerTextEmbeddingTranslator {
private String[] tokens;
private final String QUESTION_ANSWER_NAME = "question_answering";

@Override
public NDList processInput(TranslatorContext ctx, Input input) {
String contextDoc = input.getAsString(0);
String question = input.getAsString(1);
NDManager manager = ctx.getNDManager();
NDList ndList = new NDList();

Encoding encodings = tokenizer.encode(question, contextDoc);
this.tokens = encodings.getTokens();
long[] indices = encodings.getIds();
long[] attentionMask = encodings.getAttentionMask();
long[] tokenTypes = encodings.getTypeIds();

NDArray indicesArray = manager.create(indices);
indicesArray.setName("input_ids");

NDArray attentionMaskArray = manager.create(attentionMask);
attentionMaskArray.setName("attention_mask");

NDArray tokenTypeArray = manager.create(tokenTypes);
tokenTypeArray.setName("token_type_ids");

ndList.add(indicesArray);
ndList.add(attentionMaskArray);
ndList.add(tokenTypeArray);
return ndList;
}

@Override
public Output processOutput(TranslatorContext ctx, NDList list) {
Output output = new Output(200, "OK");
List<ModelTensor> outputs = new ArrayList<>();
Iterator<NDArray> iterator = list.iterator();

while (iterator.hasNext()) {
NDArray ndArray = iterator.next();
String name = QUESTION_ANSWER_NAME;
Number[] data = ndArray.toArray();
long[] shape = ndArray.getShape().getShape();
DataType dataType = ndArray.getDataType();

}
return output;
}

@Override
public Batchifier getBatchifier() {
return Batchifier.STACK;
}

}

0 comments on commit 5989acf

Please sign in to comment.