diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/FillMaskBatchTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/FillMaskBatchTranslator.java index 9a4ccba42b5..ee4e9cf9601 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/FillMaskBatchTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/FillMaskBatchTranslator.java @@ -29,13 +29,19 @@ public class FillMaskBatchTranslator implements NoBatchifyTranslator { private String maskToken; private long maskTokenId; private int topK; + private boolean includeTokenTypes; private Batchifier batchifier; FillMaskTranslator( - HuggingFaceTokenizer tokenizer, String maskToken, int topK, Batchifier batchifier) { + HuggingFaceTokenizer tokenizer, + String maskToken, + int topK, + boolean includeTokenTypes, + Batchifier batchifier) { this.tokenizer = tokenizer; this.maskToken = maskToken; this.topK = topK; + this.includeTokenTypes = includeTokenTypes; this.batchifier = batchifier; Encoding encoding = tokenizer.encode(maskToken, false, false); maskTokenId = encoding.getIds()[0]; @@ -61,7 +67,7 @@ public NDList processInput(TranslatorContext ctx, String input) throws Translate long[] indices = encoding.getIds(); int maskIndex = getMaskIndex(indices, maskToken, maskTokenId); ctx.setAttachment("maskIndex", maskIndex); - return encoding.toNDList(ctx.getNDManager(), false); + return encoding.toNDList(ctx.getNDManager(), includeTokenTypes); } /** {@inheritDoc} */ @@ -75,7 +81,8 @@ public Classifications processOutput(TranslatorContext ctx, NDList list) { @Override public FillMaskBatchTranslator toBatchTranslator(Batchifier batchifier) { tokenizer.enableBatch(); - return new FillMaskBatchTranslator(tokenizer, maskToken, topK, batchifier); + return new FillMaskBatchTranslator( + tokenizer, maskToken, topK, includeTokenTypes, batchifier); } static int getMaskIndex(long[] indices, String maskToken, long maskTokenId) @@ -139,6 +146,7 @@ public static final class Builder { private HuggingFaceTokenizer tokenizer; private String maskedToken = "[MASK]"; private int topK = 5; + private boolean includeTokenTypes; private Batchifier batchifier = Batchifier.STACK; Builder(HuggingFaceTokenizer tokenizer) { @@ -167,6 +175,17 @@ public Builder optTopK(int topK) { return this; } + /** + * Sets if include token types for the {@link Translator}. + * + * @param includeTokenTypes true to include token types + * @return this builder + */ + public Builder optIncludeTokenTypes(boolean includeTokenTypes) { + this.includeTokenTypes = includeTokenTypes; + return this; + } + /** * Sets the {@link Batchifier} for the {@link Translator}. * @@ -186,6 +205,7 @@ public Builder optBatchifier(Batchifier batchifier) { public void configure(Map arguments) { optMaskToken(ArgumentsUtil.stringValue(arguments, "maskToken", "[MASK]")); optTopK(ArgumentsUtil.intValue(arguments, "topK", 5)); + optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes")); String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack"); optBatchifier(Batchifier.fromString(batchifierStr)); } @@ -197,7 +217,8 @@ public void configure(Map arguments) { * @throws IOException if I/O error occurs */ public FillMaskTranslator build() throws IOException { - return new FillMaskTranslator(tokenizer, maskedToken, topK, batchifier); + return new FillMaskTranslator( + tokenizer, maskedToken, topK, includeTokenTypes, batchifier); } } } diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationBatchTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationBatchTranslator.java index 6c9beda2531..72be9a8aa9a 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationBatchTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationBatchTranslator.java @@ -32,11 +32,14 @@ public class TextClassificationBatchTranslator implements NoBatchifyTranslator { private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; private Batchifier batchifier; private PretrainedConfig config; - TextClassificationBatchTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) { + TextClassificationBatchTranslator( + HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) { this.tokenizer = tokenizer; + this.includeTokenTypes = includeTokenTypes; this.batchifier = batchifier; } @@ -56,7 +59,7 @@ public NDList processInput(TranslatorContext ctx, String[] inputs) { Encoding[] encodings = tokenizer.batchEncode(inputs); NDList[] batch = new NDList[encodings.length]; for (int i = 0; i < encodings.length; ++i) { - batch[i] = encodings[i].toNDList(manager, false); + batch[i] = encodings[i].toNDList(manager, includeTokenTypes); } return batchifier.batchify(batch); } diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationTranslator.java index d624b69d700..defa0f90b2f 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TextClassificationTranslator.java @@ -35,11 +35,14 @@ public class TextClassificationTranslator implements Translator { private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; private Batchifier batchifier; private PretrainedConfig config; - TextClassificationTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) { + TextClassificationTranslator( + HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) { this.tokenizer = tokenizer; + this.includeTokenTypes = includeTokenTypes; this.batchifier = batchifier; } @@ -63,7 +66,7 @@ public void prepare(TranslatorContext ctx) throws IOException { @Override public NDList processInput(TranslatorContext ctx, String input) { Encoding encoding = tokenizer.encode(input); - return encoding.toNDList(ctx.getNDManager(), false); + return encoding.toNDList(ctx.getNDManager(), includeTokenTypes); } /** {@inheritDoc} */ @@ -76,7 +79,7 @@ public Classifications processOutput(TranslatorContext ctx, NDList list) { @Override public TextClassificationBatchTranslator toBatchTranslator(Batchifier batchifier) { tokenizer.enableBatch(); - return new TextClassificationBatchTranslator(tokenizer, batchifier); + return new TextClassificationBatchTranslator(tokenizer, includeTokenTypes, batchifier); } static Classifications toClassifications(PretrainedConfig config, NDList list) { @@ -127,12 +130,24 @@ public static Builder builder(HuggingFaceTokenizer tokenizer, Map arg public static final class Builder { private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; private Batchifier batchifier = Batchifier.STACK; Builder(HuggingFaceTokenizer tokenizer) { this.tokenizer = tokenizer; } + /** + * Sets if include token types for the {@link Translator}. + * + * @param includeTokenTypes true to include token types + * @return this builder + */ + public Builder optIncludeTokenTypes(boolean includeTokenTypes) { + this.includeTokenTypes = includeTokenTypes; + return this; + } + /** * Sets the {@link Batchifier} for the {@link Translator}. * @@ -150,6 +165,7 @@ public Builder optBatchifier(Batchifier batchifier) { * @param arguments the model arguments */ public void configure(Map arguments) { + optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes")); String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack"); optBatchifier(Batchifier.fromString(batchifierStr)); } @@ -161,7 +177,7 @@ public void configure(Map arguments) { * @throws IOException if I/O error occurs */ public TextClassificationTranslator build() throws IOException { - return new TextClassificationTranslator(tokenizer, batchifier); + return new TextClassificationTranslator(tokenizer, includeTokenTypes, batchifier); } } } diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationBatchTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationBatchTranslator.java index 2ae45438ccd..b3b4ac0d3db 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationBatchTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationBatchTranslator.java @@ -32,11 +32,14 @@ public class TokenClassificationBatchTranslator implements NoBatchifyTranslator { private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; private Batchifier batchifier; private PretrainedConfig config; - TokenClassificationBatchTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) { + TokenClassificationBatchTranslator( + HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) { this.tokenizer = tokenizer; + this.includeTokenTypes = includeTokenTypes; this.batchifier = batchifier; } @@ -58,7 +61,7 @@ public NDList processInput(TranslatorContext ctx, String[] inputs) { ctx.setAttachment("encodings", encodings); NDList[] batch = new NDList[encodings.length]; for (int i = 0; i < encodings.length; ++i) { - batch[i] = encodings[i].toNDList(manager, false); + batch[i] = encodings[i].toNDList(manager, includeTokenTypes); } return batchifier.batchify(batch); } diff --git a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationTranslator.java b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationTranslator.java index b1106c244a6..5e846be147b 100644 --- a/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationTranslator.java +++ b/extensions/tokenizers/src/main/java/ai/djl/huggingface/translator/TokenClassificationTranslator.java @@ -36,11 +36,14 @@ public class TokenClassificationTranslator implements Translator { private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; private Batchifier batchifier; private PretrainedConfig config; - TokenClassificationTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) { + TokenClassificationTranslator( + HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) { this.tokenizer = tokenizer; + this.includeTokenTypes = includeTokenTypes; this.batchifier = batchifier; } @@ -65,7 +68,7 @@ public void prepare(TranslatorContext ctx) throws IOException { public NDList processInput(TranslatorContext ctx, String input) { Encoding encoding = tokenizer.encode(input); ctx.setAttachment("encoding", encoding); - return encoding.toNDList(ctx.getNDManager(), false); + return encoding.toNDList(ctx.getNDManager(), includeTokenTypes); } /** {@inheritDoc} */ @@ -79,7 +82,7 @@ public NamedEntity[] processOutput(TranslatorContext ctx, NDList list) { @Override public TokenClassificationBatchTranslator toBatchTranslator(Batchifier batchifier) { tokenizer.enableBatch(); - return new TokenClassificationBatchTranslator(tokenizer, batchifier); + return new TokenClassificationBatchTranslator(tokenizer, includeTokenTypes, batchifier); } /** @@ -139,12 +142,24 @@ static NamedEntity[] toNamedEntities(Encoding encoding, NDList list, PretrainedC public static final class Builder { private HuggingFaceTokenizer tokenizer; + private boolean includeTokenTypes; private Batchifier batchifier = Batchifier.STACK; Builder(HuggingFaceTokenizer tokenizer) { this.tokenizer = tokenizer; } + /** + * Sets if include token types for the {@link Translator}. + * + * @param includeTokenTypes true to include token types + * @return this builder + */ + public Builder optIncludeTokenTypes(boolean includeTokenTypes) { + this.includeTokenTypes = includeTokenTypes; + return this; + } + /** * Sets the {@link Batchifier} for the {@link Translator}. * @@ -162,6 +177,7 @@ public Builder optBatchifier(Batchifier batchifier) { * @param arguments the model arguments */ public void configure(Map arguments) { + optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes")); String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack"); optBatchifier(Batchifier.fromString(batchifierStr)); } @@ -173,7 +189,7 @@ public void configure(Map arguments) { * @throws IOException if I/O error occurs */ public TokenClassificationTranslator build() throws IOException { - return new TokenClassificationTranslator(tokenizer, batchifier); + return new TokenClassificationTranslator(tokenizer, includeTokenTypes, batchifier); } } }