Skip to content

Commit

Permalink
extract an abstract model for sparse and dense sentence transformer t…
Browse files Browse the repository at this point in the history
…ranslator

Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual committed Sep 26, 2023
1 parent 4e3dc78 commit a837080
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 73 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package org.opensearch.ml.engine.algorithms;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.Batchifier;
import ai.djl.translate.ServingTranslator;
import ai.djl.translate.TranslatorContext;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * Shared base for sentence-transformer translators.
 *
 * <p>Provides the common tokenizer setup and input pre-processing; concrete
 * subclasses supply the output post-processing ({@code processOutput}), which
 * this class leaves unimplemented.
 */
public abstract class SentenceTransformerTranslator implements ServingTranslator {
    /** File name of the tokenizer definition, resolved against the model path. */
    private static final String TOKENIZER_FILE_NAME = "tokenizer.json";

    /** Name assigned to the token-id tensor handed to the model. */
    private static final String INPUT_IDS_NAME = "input1.input_ids";

    /** Name assigned to the attention-mask tensor handed to the model. */
    private static final String ATTENTION_MASK_NAME = "input1.attention_mask";

    /** Tokenizer built in {@link #prepare(TranslatorContext)}; {@code null} until then. */
    protected HuggingFaceTokenizer tokenizer;

    /**
     * Returns the batchifier used by this translator.
     *
     * @return {@link Batchifier#STACK}
     */
    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }

    /**
     * Builds the tokenizer from the {@code tokenizer.json} file located in the
     * model path of the given context, with padding enabled.
     *
     * @param ctx translator context that supplies the model (and thus its path)
     * @throws IOException if building the tokenizer fails
     */
    @Override
    public void prepare(TranslatorContext ctx) throws IOException {
        Path modelPath = ctx.getModel().getModelPath();
        tokenizer = HuggingFaceTokenizer
            .builder()
            .optPadding(true)
            .optTokenizerPath(modelPath.resolve(TOKENIZER_FILE_NAME))
            .build();
    }

    /**
     * Tokenizes the sentence stored at index 0 of {@code input} and converts the
     * encoding into model input tensors.
     *
     * @param ctx   translator context supplying the {@link NDManager} that creates the arrays
     * @param input request whose first element is read as the sentence string
     * @return a list holding, in order, the token-id array (named
     *         {@code input1.input_ids}) and the attention-mask array (named
     *         {@code input1.attention_mask})
     */
    @Override
    public NDList processInput(TranslatorContext ctx, Input input) {
        String sentence = input.getAsString(0);
        NDManager manager = ctx.getNDManager();
        Encoding encoding = tokenizer.encode(sentence);

        NDArray inputIds = manager.create(encoding.getIds());
        inputIds.setName(INPUT_IDS_NAME);

        NDArray attentionMask = manager.create(encoding.getAttentionMask());
        attentionMask.setName(ATTENTION_MASK_NAME);

        // Order matters to callers: ids first, attention mask second.
        NDList tensors = new NDList();
        tensors.add(inputIds);
        tensors.add(attentionMask);
        return tensors;
    }

    /**
     * Intentionally a no-op: the supplied arguments are ignored.
     *
     * @param arguments ignored
     */
    @Override
    public void setArguments(Map<String, ?> arguments) {
        // No configurable arguments.
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.SentenceTransformerTranslator;

import java.io.IOException;
import java.nio.ByteBuffer;
Expand All @@ -27,38 +28,7 @@

import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY;

public class SparseEncodingTranslator implements ServingTranslator {
private HuggingFaceTokenizer tokenizer;

@Override
public Batchifier getBatchifier() {
return Batchifier.STACK;
}
@Override
public void prepare(TranslatorContext ctx) throws IOException {
Path path = ctx.getModel().getModelPath();
tokenizer = HuggingFaceTokenizer.builder().optPadding(true).optTokenizerPath(path.resolve("tokenizer.json")).build();
}

@Override
public NDList processInput(TranslatorContext ctx, Input input) {
String sentence = input.getAsString(0);
NDManager manager = ctx.getNDManager();
NDList ndList = new NDList();
Encoding encodings = tokenizer.encode(sentence);
long[] indices = encodings.getIds();
long[] attentionMask = encodings.getAttentionMask();

NDArray indicesArray = manager.create(indices);
indicesArray.setName("input1.input_ids");

NDArray attentionMaskArray = manager.create(attentionMask);
attentionMaskArray.setName("input1.attention_mask");

ndList.add(indicesArray);
ndList.add(attentionMaskArray);
return ndList;
}
public class SparseEncodingTranslator extends SentenceTransformerTranslator {
private Map<String, Float> convertOutput(NDArray array)
{
Map<String, Float> map = new HashMap<>();
Expand Down Expand Up @@ -94,8 +64,4 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
output.add(modelTensorOutput.toBytes());
return output;
}

@Override
public void setArguments(Map<String, ?> arguments) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.SentenceTransformerTranslator;

import java.io.IOException;
import java.nio.ByteBuffer;
Expand All @@ -28,39 +29,7 @@
import java.util.List;
import java.util.Map;

public class SentenceTransformerTextEmbeddingTranslator implements ServingTranslator {
private HuggingFaceTokenizer tokenizer;

@Override
public Batchifier getBatchifier() {
return Batchifier.STACK;
}
@Override
public void prepare(TranslatorContext ctx) throws IOException {
Path path = ctx.getModel().getModelPath();
tokenizer = HuggingFaceTokenizer.builder().optPadding(true).optTokenizerPath(path.resolve("tokenizer.json")).build();
}

@Override
public NDList processInput(TranslatorContext ctx, Input input) {
String sentence = input.getAsString(0);
NDManager manager = ctx.getNDManager();
NDList ndList = new NDList();
Encoding encodings = tokenizer.encode(sentence);
long[] indices = encodings.getIds();
long[] attentionMask = encodings.getAttentionMask();

NDArray indicesArray = manager.create(indices);
indicesArray.setName("input1.input_ids");

NDArray attentionMaskArray = manager.create(attentionMask);
attentionMaskArray.setName("input1.attention_mask");

ndList.add(indicesArray);
ndList.add(attentionMaskArray);
return ndList;
}

public class SentenceTransformerTextEmbeddingTranslator extends SentenceTransformerTranslator {
@Override
public Output processOutput(TranslatorContext ctx, NDList list) {
Output output = new Output(200, "OK");
Expand Down Expand Up @@ -89,8 +58,4 @@ public Output processOutput(TranslatorContext ctx, NDList list) {
output.add(modelTensorOutput.toBytes());
return output;
}

@Override
public void setArguments(Map<String, ?> arguments) {
}
}

0 comments on commit a837080

Please sign in to comment.