Skip to content

Commit

Permalink
Allow overriding special token flags in encode and decode methods (#1855)
Browse files Browse the repository at this point in the history
  • Loading branch information
siddvenk authored Aug 2, 2022
1 parent 582210f commit 927187b
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 13 deletions.
4 changes: 2 additions & 2 deletions extensions/tokenizers/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
_: JObject,
handle: jlong,
ids: jlongArray,
add_special_tokens: jboolean,
skip_special_tokens: jboolean,
) -> jstring {
let tokenizer = cast_handle::<Tokenizer>(handle);
let long_ids = env.get_long_array_elements(ids, ReleaseMode::NoCopyBack).unwrap();
Expand All @@ -396,7 +396,7 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
decode_ids.push(*val as u32);
}
}
let decoding: String = tokenizer.decode(decode_ids, add_special_tokens == JNI_TRUE).unwrap();
let decoding: String = tokenizer.decode(decode_ids, skip_special_tokens == JNI_TRUE).unwrap();
let ret = env
.new_string(decoding)
.expect("Couldn't create java string!")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,35 +140,62 @@ public void close() {
* Returns the {@code Encoding} of the input sentence.
*
* @param text the input sentence
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @return the {@code Encoding} of the input sentence
*/
public Encoding encode(String text) {
public Encoding encode(String text, boolean addSpecialTokens) {
long encoding = TokenizersLibrary.LIB.encode(getHandle(), text, addSpecialTokens);
return toEncoding(encoding);
}

/**
 * Encodes a single sentence, applying the tokenizer-level {@code addSpecialTokens} value.
 *
 * @param text the input sentence
 * @return the {@code Encoding} of the input sentence
 */
public Encoding encode(String text) {
    // Forward to the overload that takes the flag explicitly.
    Encoding result = encode(text, addSpecialTokens);
    return result;
}

/**
* Returns the {@code Encoding} of the input sentence.
*
* @param text the input sentence
* @param textPair the second input sentence
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @return the {@code Encoding} of the input sentence
*/
public Encoding encode(String text, String textPair) {
public Encoding encode(String text, String textPair, boolean addSpecialTokens) {
long encoding =
TokenizersLibrary.LIB.encodeDual(getHandle(), text, textPair, addSpecialTokens);
return toEncoding(encoding);
}

/**
 * Encodes a pair of sentences, applying the tokenizer-level {@code addSpecialTokens} value.
 *
 * @param text the input sentence
 * @param textPair the second input sentence
 * @return the {@code Encoding} of the input sentence pair
 */
public Encoding encode(String text, String textPair) {
    // Forward to the overload that takes the flag explicitly.
    Encoding result = encode(text, textPair, addSpecialTokens);
    return result;
}

/**
* Returns the {@code Encoding} of the input sentences.
*
* @param inputs the input sentences
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @return the {@code Encoding} of the input sentences
*/
public Encoding encode(List<String> inputs) {
public Encoding encode(List<String> inputs, boolean addSpecialTokens) {
String[] array = inputs.toArray(new String[0]);
return encode(array);
return encode(array, addSpecialTokens);
}

/**
Expand All @@ -177,20 +204,44 @@ public Encoding encode(List<String> inputs) {
* @param inputs the input sentences
* @return the {@code Encoding} of the input sentences
*/
public Encoding encode(String[] inputs) {
public Encoding encode(List<String> inputs) {
return encode(inputs, addSpecialTokens);
}

/**
 * Encodes an array of sentences into one {@code Encoding}, with the special-token behaviour
 * chosen by the caller.
 *
 * @param inputs the input sentences
 * @param addSpecialTokens whether to encode the sequence with special tokens relative to their
 *     model
 * @return the {@code Encoding} of the input sentences
 */
public Encoding encode(String[] inputs, boolean addSpecialTokens) {
    // The native call returns a handle that toEncoding turns into an Encoding.
    return toEncoding(
            TokenizersLibrary.LIB.encodeList(getHandle(), inputs, addSpecialTokens));
}

/**
 * Encodes an array of sentences, applying the tokenizer-level {@code addSpecialTokens} value.
 *
 * @param inputs the input sentences
 * @return the {@code Encoding} of the input sentences
 */
public Encoding encode(String[] inputs) {
    // Forward to the overload that takes the flag explicitly.
    Encoding result = encode(inputs, addSpecialTokens);
    return result;
}

/**
* Returns the {@code Encoding} of the input sentence in batch.
*
* @param inputs the batch of input sentence
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @return the {@code Encoding} of the input sentence in batch
*/
public Encoding[] batchEncode(List<String> inputs) {
public Encoding[] batchEncode(List<String> inputs, boolean addSpecialTokens) {
String[] array = inputs.toArray(new String[0]);
return batchEncode(array);
return batchEncode(array, addSpecialTokens);
}

/**
Expand All @@ -199,7 +250,19 @@ public Encoding[] batchEncode(List<String> inputs) {
* @param inputs the batch of input sentence
* @return the {@code Encoding} of the input sentence in batch
*/
public Encoding[] batchEncode(String[] inputs) {
public Encoding[] batchEncode(List<String> inputs) {
return batchEncode(inputs, addSpecialTokens);
}

/**
* Returns the {@code Encoding} of the input sentence in batch.
*
* @param inputs the batch of input sentence
* @param addSpecialTokens whether to encode the sequence with special tokens relative to their
* model
* @return the {@code Encoding} of the input sentence in batch
*/
public Encoding[] batchEncode(String[] inputs, boolean addSpecialTokens) {
long[] encodings = TokenizersLibrary.LIB.batchEncode(getHandle(), inputs, addSpecialTokens);
Encoding[] ret = new Encoding[encodings.length];
for (int i = 0; i < encodings.length; ++i) {
Expand All @@ -208,14 +271,35 @@ public Encoding[] batchEncode(String[] inputs) {
return ret;
}

/**
 * Encodes a batch of sentences, applying the tokenizer-level {@code addSpecialTokens} value.
 *
 * @param inputs the batch of input sentence
 * @return the {@code Encoding} of the input sentence in batch
 */
public Encoding[] batchEncode(String[] inputs) {
    // Forward to the overload that takes the flag explicitly.
    Encoding[] results = batchEncode(inputs, addSpecialTokens);
    return results;
}

/**
 * Decodes token ids back into text, with special-token handling chosen by the caller.
 *
 * @param ids the input ids
 * @param skipSpecialTokens whether to remove special tokens in the decoding
 * @return the decoded String from the input ids
 */
public String decode(long[] ids, boolean skipSpecialTokens) {
    // The flag is passed through to the native decoder unchanged.
    String decoded = TokenizersLibrary.LIB.decode(getHandle(), ids, skipSpecialTokens);
    return decoded;
}

/**
* Returns the decoded String from the input ids.
*
* @param ids the input ids
* @return the decoded String from the input ids
*/
public String decode(long[] ids) {
return TokenizersLibrary.LIB.decode(getHandle(), ids, addSpecialTokens);
return decode(ids, !addSpecialTokens);
}

private Encoding toEncoding(long encoding) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,35 @@ public void testTokenizerDecoding() {
101, 3570, 1110, 170, 21162, 1285, 119, 2750, 4250, 146, 112, 173, 1474,
102
});
List<String> expectedDecodings =
List<String> expectedDecodingsNoSpecialTokens =
Arrays.asList(
"Hello, y ' all! How are you?",
"Today is a sunny day. Good weather I ' d say");
List<String> expectedDecodingsWithSpecialTokens =
Arrays.asList(
"[CLS] Hello, y ' all! How are you [UNK]? [SEP]",
"[CLS] Today is a sunny day. Good weather I ' d say [SEP]");
try (HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance("bert-base-cased")) {
for (int i = 0; i < testIds.size(); ++i) {
Assert.assertEquals(tokenizer.decode(testIds.get(i)), expectedDecodings.get(i));
Assert.assertEquals(
tokenizer.decode(testIds.get(i)),
expectedDecodingsWithSpecialTokens.get(i));
Assert.assertEquals(
tokenizer.decode(testIds.get(i), true),
expectedDecodingsNoSpecialTokens.get(i));
}
}

Map<String, String> options = new ConcurrentHashMap<>();
options.put("addSpecialTokens", "false");
try (HuggingFaceTokenizer tokenizer =
HuggingFaceTokenizer.newInstance("bert-base-cased", options)) {
for (int i = 0; i < testIds.size(); ++i) {
Assert.assertEquals(
tokenizer.decode(testIds.get(i)), expectedDecodingsNoSpecialTokens.get(i));
Assert.assertEquals(
tokenizer.decode(testIds.get(i), false),
expectedDecodingsWithSpecialTokens.get(i));
}
}
}
Expand Down

0 comments on commit 927187b

Please sign in to comment.