fix: Don't erroneously mess with huggingface offset ranges
I had previously assumed the offsets were end-exclusive and needed to be made inclusive. This isn't actually the case, and the adjustment may not even be consistent beyond ByteLevel pre-tokenizers. It seems the offsets refer to the words themselves and don't need to be adjusted, so we don't have to compensate for whitespace prefixing either: the offsets stay the same even for prefixed text.
benbrandt committed Apr 3, 2024
1 parent 7309bca commit 8ba3841
Showing 1 changed file with 10 additions and 25 deletions.
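A minimal sketch of the behavior described in the commit message, assuming the `tokenizers` crate and the repository's test fixture; the expected offsets in the comments come from the updated tests in this commit:

```rust
use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    // Same fixture the tests below load; any ByteLevel tokenizer with
    // offset trimming should (by assumption) behave similarly.
    let tokenizer = Tokenizer::from_file("./tests/tokenizers/huggingface.json")?;

    let plain = tokenizer.encode("An apple a", false)?;
    let prefixed = tokenizer.encode(" An apple a", false)?;

    // The offsets point at the word bytes themselves, so prefixing the text
    // with a space shifts every range by one and nothing else changes:
    //   "An apple a"  -> [(0, 2), (3, 8), (9, 10)]
    //   " An apple a" -> [(1, 3), (4, 9), (10, 11)]
    println!("{:?}", plain.get_offsets());
    println!("{:?}", prefixed.get_offsets());
    Ok(())
}
```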
35 changes: 10 additions & 25 deletions src/chunk_size/huggingface.rs
```diff
@@ -14,34 +14,18 @@ impl ChunkSizer for &Tokenizer {
             .encode(chunk, false)
             .expect("Unable to tokenize the following string {chunk}");
 
-        let padding_params = self.get_padding();
+        let pad_id = self.get_padding().map(|params| params.pad_id);
 
-        let mut offsets = encoding
+        let offsets = encoding
             .get_ids()
             .iter()
             .zip(encoding.get_offsets())
             // Skip padding tokens at beginning and end so they don't count towards the chunk size
-            .skip_while(|&(id, _)| padding_params.map_or(false, |params| id == &params.pad_id))
-            .take_while(|&(id, _)| padding_params.map_or(true, |params| id != &params.pad_id))
-            .map(|(_, (start, end))| {
-                let end = *end + 1;
-                *start..end
-            })
-            .collect::<Vec<_>>();
+            .skip_while(|&(id, _)| pad_id.map_or(false, |pad_id| id == &pad_id))
+            .take_while(|&(id, _)| pad_id.map_or(true, |pad_id| id != &pad_id))
+            .map(|(_, (start, end))| *start..*end);
 
-        // Sometimes the offsets are off by one because of whitespace prefixing
-        let prefixed = offsets.last().is_some_and(|r| r.end != chunk.len());
-
-        if prefixed {
-            for range in &mut offsets {
-                if range.start != 0 {
-                    range.start -= 1;
-                }
-                range.end -= 1;
-            }
-        }
-
-        ChunkSize::from_offsets(offsets.into_iter(), capacity)
+        ChunkSize::from_offsets(offsets, capacity)
     }
 }
```
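The padding trim in the rewritten iterator chain can be exercised in isolation. This is a standalone sketch with made-up token ids and offsets (pad id 0), not code from the repository:

```rust
// skip_while drops leading pad tokens, take_while stops at the first trailing
// pad, and map converts the remaining (start, end) pairs into ranges.
fn main() {
    let pad_id: Option<u32> = Some(0);
    let ids = [0u32, 0, 17, 23, 42, 0, 0];
    let offsets = [(0, 0), (0, 0), (0, 2), (3, 8), (9, 10), (0, 0), (0, 0)];

    let ranges: Vec<_> = ids
        .iter()
        .zip(offsets.iter())
        .skip_while(|&(id, _)| pad_id.map_or(false, |p| *id == p))
        .take_while(|&(id, _)| pad_id.map_or(true, |p| *id != p))
        .map(|(_, &(start, end))| start..end)
        .collect();

    // Only the three real tokens survive; the framing pads are gone.
    assert_eq!(ranges, [0..2, 3..8, 9..10]);
}
```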

```diff
@@ -68,19 +52,20 @@ mod tests {
         let offsets = tokenizer.chunk_size(" An apple a", &capacity);
         assert_eq!(
             offsets,
-            ChunkSize::from_offsets([0..3, 3..9, 9..11].into_iter(), &capacity)
+            ChunkSize::from_offsets([1..3, 4..9, 10..11].into_iter(), &capacity)
         );
     }
 
     #[test]
     fn returns_offsets_handles_prefix() {
-        let tokenizer = Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
+        let tokenizer =
+            tokenizers::Tokenizer::from_file("./tests/tokenizers/huggingface.json").unwrap();
 
         let capacity = 10;
         let offsets = tokenizer.chunk_size("An apple a", &capacity);
         assert_eq!(
             offsets,
-            ChunkSize::from_offsets([0..2, 2..8, 8..10].into_iter(), &capacity)
+            ChunkSize::from_offsets([0..2, 3..8, 9..10].into_iter(), &capacity)
         );
     }
```
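As a sanity check on the new expectations, the ranges are exactly the byte indices of the words; plain Rust string slicing confirms the arithmetic:

```rust
fn main() {
    // Byte indices in the prefixed input " An apple a":
    //   0: ' '  1: 'A'  2: 'n'  3: ' '  4..9: "apple"  9: ' '  10: 'a'
    let text = " An apple a";
    assert_eq!(&text[1..3], "An");
    assert_eq!(&text[4..9], "apple");
    assert_eq!(&text[10..11], "a");

    // Without the prefix, every range simply shifts left by one.
    let text = "An apple a";
    assert_eq!(&text[0..2], "An");
    assert_eq!(&text[3..8], "apple");
    assert_eq!(&text[9..10], "a");
}
```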

