Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tokenizers] support stride in tokenizers #2006

Merged
merged 1 commit into from
Sep 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions extensions/tokenizers/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,25 @@ pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_
array
}

/// JNI entry point returning the overflowing encodings of an `Encoding`.
///
/// `handle` is resolved to an `Encoding` through `cast_handle`. Every entry
/// of its overflowing list is cloned and turned into a handle with
/// `to_handle`; the handles are returned to Java as a `long[]` in the same
/// order as the overflowing list.
///
/// Panics if the Java array cannot be created or filled.
#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_getOverflowing(
    env: JNIEnv,
    _: JObject,
    handle: jlong,
) -> jlongArray {
    let encoding = cast_handle::<Encoding>(handle);
    let overflowing = encoding.get_overflowing();

    // One handle per overflowing encoding; each handle is built from a clone,
    // so the source encoding keeps its own list untouched.
    let mut handles: Vec<jlong> = Vec::with_capacity(overflowing.len());
    for item in overflowing.iter() {
        handles.push(to_handle(item.clone()));
    }

    let array = env.new_long_array(handles.len() as jsize).unwrap();
    env.set_long_array_region(array, 0, &handles).unwrap();
    array
}

#[no_mangle]
pub extern "system" fn Java_ai_djl_huggingface_tokenizers_jni_TokenizersLibrary_decode(
env: JNIEnv,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public class Encoding {
private long[] attentionMask;
private long[] specialTokenMask;
private CharSpan[] charTokenSpans;
private Encoding[] overflowing;

protected Encoding(
long[] ids,
Expand All @@ -32,14 +33,16 @@ protected Encoding(
long[] wordIds,
long[] attentionMask,
long[] specialTokenMask,
CharSpan[] charTokenSpans) {
CharSpan[] charTokenSpans,
Encoding[] overflowing) {
this.ids = ids;
this.typeIds = typeIds;
this.tokens = tokens;
this.wordIds = wordIds;
this.attentionMask = attentionMask;
this.specialTokenMask = specialTokenMask;
this.charTokenSpans = charTokenSpans;
this.overflowing = overflowing;
}

/**
Expand Down Expand Up @@ -104,4 +107,13 @@ public long[] getSpecialTokenMask() {
public CharSpan[] getCharTokenSpans() {
return charTokenSpans;
}

/**
 * Gets the overflowing encodings stored in this encoding.
 *
 * <p>The array passed to the constructor is returned as is, not a copy.
 *
 * @return the overflowing encodings
 */
public Encoding[] getOverflowing() {
    return this.overflowing;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -419,10 +419,23 @@ private Encoding toEncoding(long encoding) {
long[] attentionMask = TokenizersLibrary.LIB.getAttentionMask(encoding);
long[] specialTokenMask = TokenizersLibrary.LIB.getSpecialTokenMask(encoding);
CharSpan[] charSpans = TokenizersLibrary.LIB.getTokenCharSpans(encoding);
long[] overflowingHandles = TokenizersLibrary.LIB.getOverflowing(encoding);

Encoding[] overflowing = new Encoding[overflowingHandles.length];
for (int i = 0; i < overflowingHandles.length; ++i) {
overflowing[i] = toEncoding(overflowingHandles[i]);
}

TokenizersLibrary.LIB.deleteEncoding(encoding);
return new Encoding(
ids, typeIds, tokens, wordIds, attentionMask, specialTokenMask, charSpans);
ids,
typeIds,
tokens,
wordIds,
attentionMask,
specialTokenMask,
charSpans,
overflowing);
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -621,6 +634,18 @@ public Builder optPadToMultipleOf(int padToMultipleOf) {
return this;
}

/**
 * Configures the overlap, in tokens, between consecutive pieces when a sequence longer
 * than the model supports is truncated.
 *
 * <p>The value is recorded in the builder options under the {@code "stride"} key.
 *
 * @param stride the number of tokens to overlap when truncating long sequences
 * @return this builder
 */
public Builder optStride(int stride) {
    String strideValue = Integer.toString(stride);
    options.put("stride", strideValue);
    return this;
}

/**
* Configures the builder with the arguments.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ public native long encodeDual(

public native CharSpan[] getTokenCharSpans(long encoding);

/**
 * Retrieves the overflowing encodings of a native encoding.
 *
 * <p>NOTE(review): the native side appears to hand back a handle to a separate copy of
 * each overflowing encoding, so each returned handle presumably has to be released by
 * the caller - confirm against the native implementation.
 *
 * @param encoding the native handle of the encoding to query
 * @return an array with one native encoding handle per overflowing encoding
 */
public native long[] getOverflowing(long encoding);

public native String decode(long tokenizer, long[] ids, boolean addSpecialTokens);

public native String getTruncationStrategy(long tokenizer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,52 @@ public void testNoTruncationAndAllPaddings() throws IOException {
}
}

/**
 * Checks the overflowing encodings produced when truncation is combined with a stride,
 * first for a batch of single sentences and then for a sentence pair.
 *
 * @throws IOException if the tokenizer cannot be built
 */
@Test
public void testTruncationStride() throws IOException {
    // Case 1: batch of single inputs, max length 3, stride 1.
    try (HuggingFaceTokenizer strideOneTokenizer =
            HuggingFaceTokenizer.builder()
                    .optTokenizerName("bert-base-cased")
                    .optAddSpecialTokens(false)
                    .optTruncation(true)
                    .optMaxLength(3)
                    .optStride(1)
                    .build()) {
        String[] sentences = {"Hello there my good friend", "How are you today"};
        int[] overflowCounts = {1, 1};
        int[][] overflowTokenCounts = {{3}, {2}};

        Encoding[] results = strideOneTokenizer.batchEncode(sentences);
        for (int row = 0; row < results.length; ++row) {
            Encoding[] overflow = results[row].getOverflowing();
            Assert.assertEquals(overflow.length, overflowCounts[row]);
            for (int col = 0; col < overflowCounts[row]; ++col) {
                int tokenCount = overflow[col].getTokens().length;
                Assert.assertEquals(tokenCount, overflowTokenCounts[row][col]);
            }
        }
    }

    // Case 2: a text pair, max length 8, stride 2.
    try (HuggingFaceTokenizer strideTwoTokenizer =
            HuggingFaceTokenizer.builder()
                    .optTokenizerName("bert-base-cased")
                    .optAddSpecialTokens(false)
                    .optTruncation(true)
                    .optMaxLength(8)
                    .optStride(2)
                    .build()) {
        String first = "Hello there my friend I am happy to see you";
        String second = "How are you my friend";
        int[] overflowLengths = {8, 7, 8, 7, 8, 7, 7};

        Encoding pairEncoding = strideTwoTokenizer.encode(first, second);
        Encoding[] overflow = pairEncoding.getOverflowing();

        Assert.assertEquals(overflow.length, overflowLengths.length);
        for (int idx = 0; idx < overflowLengths.length; ++idx) {
            int idCount = overflow[idx].getIds().length;
            Assert.assertEquals(idCount, overflowLengths[idx]);
        }
    }
}

@Test
public void testTruncationAndPaddingForPairInputs() throws IOException {
String text = "Hello there my friend";
Expand Down