Skip to content

Commit

Permalink
[tokenizer] Adds includeTokenTypes for all translators (#3035)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Apr 26, 2024
1 parent 55c67be commit a356dcf
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,19 @@ public class FillMaskBatchTranslator implements NoBatchifyTranslator<String[], C
private String maskToken;
private long maskTokenId;
private int topK;
private boolean includeTokenTypes;
private Batchifier batchifier;

FillMaskBatchTranslator(
HuggingFaceTokenizer tokenizer, String maskToken, int topK, Batchifier batchifier) {
HuggingFaceTokenizer tokenizer,
String maskToken,
int topK,
boolean includeTokenTypes,
Batchifier batchifier) {
this.tokenizer = tokenizer;
this.maskToken = maskToken;
this.topK = topK;
this.includeTokenTypes = includeTokenTypes;
this.batchifier = batchifier;
Encoding encoding = tokenizer.encode(maskToken, false, false);
maskTokenId = encoding.getIds()[0];
Expand All @@ -52,7 +58,7 @@ public NDList processInput(TranslatorContext ctx, String[] inputs) throws Transl
for (int i = 0; i < encodings.length; ++i) {
long[] indices = encodings[i].getIds();
maskIndices[i] = FillMaskTranslator.getMaskIndex(indices, maskToken, maskTokenId);
batch[i] = encodings[i].toNDList(manager, false);
batch[i] = encodings[i].toNDList(manager, includeTokenTypes);
}
return batchifier.batchify(batch);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,19 @@ public class FillMaskTranslator implements Translator<String, Classifications> {
private String maskToken;
private long maskTokenId;
private int topK;
private boolean includeTokenTypes;
private Batchifier batchifier;

FillMaskTranslator(
HuggingFaceTokenizer tokenizer, String maskToken, int topK, Batchifier batchifier) {
HuggingFaceTokenizer tokenizer,
String maskToken,
int topK,
boolean includeTokenTypes,
Batchifier batchifier) {
this.tokenizer = tokenizer;
this.maskToken = maskToken;
this.topK = topK;
this.includeTokenTypes = includeTokenTypes;
this.batchifier = batchifier;
Encoding encoding = tokenizer.encode(maskToken, false, false);
maskTokenId = encoding.getIds()[0];
Expand All @@ -61,7 +67,7 @@ public NDList processInput(TranslatorContext ctx, String input) throws Translate
long[] indices = encoding.getIds();
int maskIndex = getMaskIndex(indices, maskToken, maskTokenId);
ctx.setAttachment("maskIndex", maskIndex);
return encoding.toNDList(ctx.getNDManager(), false);
return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
}

/** {@inheritDoc} */
Expand All @@ -75,7 +81,8 @@ public Classifications processOutput(TranslatorContext ctx, NDList list) {
@Override
public FillMaskBatchTranslator toBatchTranslator(Batchifier batchifier) {
tokenizer.enableBatch();
return new FillMaskBatchTranslator(tokenizer, maskToken, topK, batchifier);
return new FillMaskBatchTranslator(
tokenizer, maskToken, topK, includeTokenTypes, batchifier);
}

static int getMaskIndex(long[] indices, String maskToken, long maskTokenId)
Expand Down Expand Up @@ -139,6 +146,7 @@ public static final class Builder {
private HuggingFaceTokenizer tokenizer;
private String maskedToken = "[MASK]";
private int topK = 5;
private boolean includeTokenTypes;
private Batchifier batchifier = Batchifier.STACK;

Builder(HuggingFaceTokenizer tokenizer) {
Expand Down Expand Up @@ -167,6 +175,17 @@ public Builder optTopK(int topK) {
return this;
}

/**
* Sets whether token types are included in the input built by the {@link Translator}.
*
* <p>The value is stored on this builder and handed to the translator by {@code build()};
* the translator forwards it as the second argument of {@code Encoding.toNDList} when it
* processes input. The builder field has no initializer, so the flag is {@code false}
* until it is set.
*
* @param includeTokenTypes true to include token types
* @return this builder
*/
public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
this.includeTokenTypes = includeTokenTypes;
return this;
}

/**
* Sets the {@link Batchifier} for the {@link Translator}.
*
Expand All @@ -186,6 +205,7 @@ public Builder optBatchifier(Batchifier batchifier) {
public void configure(Map<String, ?> arguments) {
// Every option is assigned unconditionally, so values set earlier through the opt*
// methods are overwritten by whatever is read from the arguments map.
// The trailing literal ("[MASK]", 5, "stack") looks like the fallback used when the key
// is missing; the first two match the builder's field initializers -- confirm against
// ArgumentsUtil.
optMaskToken(ArgumentsUtil.stringValue(arguments, "maskToken", "[MASK]"));
optTopK(ArgumentsUtil.intValue(arguments, "topK", 5));
// No fallback is passed here; presumably booleanValue yields false when
// "includeTokenTypes" is absent -- TODO confirm against ArgumentsUtil.
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
optBatchifier(Batchifier.fromString(batchifierStr));
}
Expand All @@ -197,7 +217,8 @@ public void configure(Map<String, ?> arguments) {
* @throws IOException if I/O error occurs
*/
public FillMaskTranslator build() throws IOException {
return new FillMaskTranslator(tokenizer, maskedToken, topK, batchifier);
return new FillMaskTranslator(
tokenizer, maskedToken, topK, includeTokenTypes, batchifier);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ public class TextClassificationBatchTranslator
implements NoBatchifyTranslator<String[], Classifications[]> {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private Batchifier batchifier;
private PretrainedConfig config;

TextClassificationBatchTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) {
TextClassificationBatchTranslator(
HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
this.tokenizer = tokenizer;
this.includeTokenTypes = includeTokenTypes;
this.batchifier = batchifier;
}

Expand All @@ -56,7 +59,7 @@ public NDList processInput(TranslatorContext ctx, String[] inputs) {
Encoding[] encodings = tokenizer.batchEncode(inputs);
NDList[] batch = new NDList[encodings.length];
for (int i = 0; i < encodings.length; ++i) {
batch[i] = encodings[i].toNDList(manager, false);
batch[i] = encodings[i].toNDList(manager, includeTokenTypes);
}
return batchifier.batchify(batch);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,14 @@
public class TextClassificationTranslator implements Translator<String, Classifications> {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private Batchifier batchifier;
private PretrainedConfig config;

TextClassificationTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) {
TextClassificationTranslator(
HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
this.tokenizer = tokenizer;
this.includeTokenTypes = includeTokenTypes;
this.batchifier = batchifier;
}

Expand All @@ -63,7 +66,7 @@ public void prepare(TranslatorContext ctx) throws IOException {
@Override
public NDList processInput(TranslatorContext ctx, String input) {
Encoding encoding = tokenizer.encode(input);
return encoding.toNDList(ctx.getNDManager(), false);
return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
}

/** {@inheritDoc} */
Expand All @@ -76,7 +79,7 @@ public Classifications processOutput(TranslatorContext ctx, NDList list) {
@Override
public TextClassificationBatchTranslator toBatchTranslator(Batchifier batchifier) {
tokenizer.enableBatch();
return new TextClassificationBatchTranslator(tokenizer, batchifier);
return new TextClassificationBatchTranslator(tokenizer, includeTokenTypes, batchifier);
}

static Classifications toClassifications(PretrainedConfig config, NDList list) {
Expand Down Expand Up @@ -127,12 +130,24 @@ public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arg
public static final class Builder {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private Batchifier batchifier = Batchifier.STACK;

/**
* Creates a builder that keeps the given tokenizer for the translator it builds.
*
* @param tokenizer the tokenizer stored on this builder
*/
Builder(HuggingFaceTokenizer tokenizer) {
this.tokenizer = tokenizer;
}

/**
* Sets whether token types are included in the input built by the {@link Translator}.
*
* <p>The value is stored on this builder and handed to the translator by {@code build()};
* the translator forwards it as the second argument of {@code Encoding.toNDList} when it
* processes input. The builder field has no initializer, so the flag is {@code false}
* until it is set.
*
* @param includeTokenTypes true to include token types
* @return this builder
*/
public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
this.includeTokenTypes = includeTokenTypes;
return this;
}

/**
* Sets the {@link Batchifier} for the {@link Translator}.
*
Expand All @@ -150,6 +165,7 @@ public Builder optBatchifier(Batchifier batchifier) {
* @param arguments the model arguments
*/
public void configure(Map<String, ?> arguments) {
// Both options are assigned unconditionally, so values set earlier through the opt*
// methods are overwritten by whatever is read from the arguments map.
// No fallback is passed for the boolean; presumably booleanValue yields false when
// "includeTokenTypes" is absent -- TODO confirm against ArgumentsUtil.
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
// "stack" looks like the fallback name used when no "batchifier" key is supplied --
// confirm against ArgumentsUtil.
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
optBatchifier(Batchifier.fromString(batchifierStr));
}
Expand All @@ -161,7 +177,7 @@ public void configure(Map<String, ?> arguments) {
* @throws IOException if I/O error occurs
*/
public TextClassificationTranslator build() throws IOException {
return new TextClassificationTranslator(tokenizer, batchifier);
return new TextClassificationTranslator(tokenizer, includeTokenTypes, batchifier);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ public class TokenClassificationBatchTranslator
implements NoBatchifyTranslator<String[], NamedEntity[][]> {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private Batchifier batchifier;
private PretrainedConfig config;

TokenClassificationBatchTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) {
TokenClassificationBatchTranslator(
HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
this.tokenizer = tokenizer;
this.includeTokenTypes = includeTokenTypes;
this.batchifier = batchifier;
}

Expand All @@ -58,7 +61,7 @@ public NDList processInput(TranslatorContext ctx, String[] inputs) {
ctx.setAttachment("encodings", encodings);
NDList[] batch = new NDList[encodings.length];
for (int i = 0; i < encodings.length; ++i) {
batch[i] = encodings[i].toNDList(manager, false);
batch[i] = encodings[i].toNDList(manager, includeTokenTypes);
}
return batchifier.batchify(batch);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,14 @@
public class TokenClassificationTranslator implements Translator<String, NamedEntity[]> {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private Batchifier batchifier;
private PretrainedConfig config;

TokenClassificationTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) {
TokenClassificationTranslator(
HuggingFaceTokenizer tokenizer, boolean includeTokenTypes, Batchifier batchifier) {
this.tokenizer = tokenizer;
this.includeTokenTypes = includeTokenTypes;
this.batchifier = batchifier;
}

Expand All @@ -65,7 +68,7 @@ public void prepare(TranslatorContext ctx) throws IOException {
public NDList processInput(TranslatorContext ctx, String input) {
Encoding encoding = tokenizer.encode(input);
ctx.setAttachment("encoding", encoding);
return encoding.toNDList(ctx.getNDManager(), false);
return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
}

/** {@inheritDoc} */
Expand All @@ -79,7 +82,7 @@ public NamedEntity[] processOutput(TranslatorContext ctx, NDList list) {
@Override
public TokenClassificationBatchTranslator toBatchTranslator(Batchifier batchifier) {
tokenizer.enableBatch();
return new TokenClassificationBatchTranslator(tokenizer, batchifier);
return new TokenClassificationBatchTranslator(tokenizer, includeTokenTypes, batchifier);
}

/**
Expand Down Expand Up @@ -139,12 +142,24 @@ static NamedEntity[] toNamedEntities(Encoding encoding, NDList list, PretrainedC
public static final class Builder {

private HuggingFaceTokenizer tokenizer;
private boolean includeTokenTypes;
private Batchifier batchifier = Batchifier.STACK;

/**
* Creates a builder that keeps the given tokenizer for the translator it builds.
*
* @param tokenizer the tokenizer stored on this builder
*/
Builder(HuggingFaceTokenizer tokenizer) {
this.tokenizer = tokenizer;
}

/**
* Sets whether token types are included in the input built by the {@link Translator}.
*
* <p>The value is stored on this builder and handed to the translator by {@code build()};
* the translator forwards it as the second argument of {@code Encoding.toNDList} when it
* processes input. The builder field has no initializer, so the flag is {@code false}
* until it is set.
*
* @param includeTokenTypes true to include token types
* @return this builder
*/
public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
this.includeTokenTypes = includeTokenTypes;
return this;
}

/**
* Sets the {@link Batchifier} for the {@link Translator}.
*
Expand All @@ -162,6 +177,7 @@ public Builder optBatchifier(Batchifier batchifier) {
* @param arguments the model arguments
*/
public void configure(Map<String, ?> arguments) {
// Both options are assigned unconditionally, so values set earlier through the opt*
// methods are overwritten by whatever is read from the arguments map.
// No fallback is passed for the boolean; presumably booleanValue yields false when
// "includeTokenTypes" is absent -- TODO confirm against ArgumentsUtil.
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
// "stack" looks like the fallback name used when no "batchifier" key is supplied --
// confirm against ArgumentsUtil.
String batchifierStr = ArgumentsUtil.stringValue(arguments, "batchifier", "stack");
optBatchifier(Batchifier.fromString(batchifierStr));
}
Expand All @@ -173,7 +189,7 @@ public void configure(Map<String, ?> arguments) {
* @throws IOException if I/O error occurs
*/
public TokenClassificationTranslator build() throws IOException {
return new TokenClassificationTranslator(tokenizer, batchifier);
return new TokenClassificationTranslator(tokenizer, includeTokenTypes, batchifier);
}
}
}

0 comments on commit a356dcf

Please sign in to comment.