Added support for building an AddedVocabulary based on a pre-existing AddedVocabulary. #1444

Closed
11 changes: 11 additions & 0 deletions bindings/python/py_src/tokenizers/__init__.pyi
@@ -1046,6 +1046,17 @@ class Tokenizer:
Whether the JSON file should be pretty formatted.
"""
pass
def set_added_tokens_decoder(self, added_tokens_decoder, encode_special_tokens=False):
"""
Sets the underlying added tokens vocabulary

Args:
added_tokens_decoder (:obj:`Dict[int, AddedToken]`):
Map from added token ID to :obj:`AddedToken`.
encode_special_tokens (:obj:`bool`, defaults to :obj:`False`):
Whether or not special tokens should be split when encoding. This is equivalent to ignoring their special status.
"""
pass
def to_str(self, pretty=False):
"""
Gets a serialized string representing this :class:`~tokenizers.Tokenizer`.
37 changes: 37 additions & 0 deletions bindings/python/src/tokenizer.rs
@@ -25,6 +25,7 @@ use super::trainers::PyTrainer;
use crate::processors::PyPostProcessor;
use crate::utils::{MaybeSizedIterator, PyBufferedIterator};
use std::collections::BTreeMap;
use tk::{AddedToken, AddedVocabulary};

/// Represents a token that can be added to a :class:`~tokenizers.Tokenizer`.
/// It can have special options that define the way it should behave.
@@ -662,6 +663,42 @@ impl PyTokenizer {
self.tokenizer.get_vocab(with_added_tokens)
}

/// Sets the underlying added tokens vocabulary
///
/// Args:
/// added_tokens_decoder (:obj:`Dict[int, AddedToken]`):
/// Map from added token ID to :obj:`AddedToken`.
/// encode_special_tokens (:obj:`bool`, defaults to :obj:`False`):
/// Whether or not special tokens should be split when encoding. This is equivalent to ignoring their special status.
#[pyo3(signature = (added_tokens_decoder, encode_special_tokens = false))]
#[pyo3(text_signature = "(self, added_tokens_decoder, encode_special_tokens=False)")]
fn set_added_tokens_decoder(
&mut self,
added_tokens_decoder: &PyDict,
encode_special_tokens: bool,
) -> PyResult<()> {
added_tokens_decoder
.iter()
.map(|(key, value)| {
key.extract::<u32>().and_then(|key| {
value
.extract::<PyRefMut<PyAddedToken>>()
.map(|value| (key, value.get_token()))
})
})
.collect::<Result<HashMap<u32, AddedToken>, PyErr>>()
.map(|added_tokens| {
self.tokenizer
.with_added_vocabulary(AddedVocabulary::from_indexed_added_tokens(
added_tokens,
encode_special_tokens,
self.tokenizer.get_model(),
self.tokenizer.get_normalizer(),
))
})?;
Ok(())
}

/// Get the underlying vocabulary
///
/// Returns:
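At the Rust level, the binding above boils down to the flow sketched below: collect the indexed tokens into a `HashMap<u32, AddedToken>`, build an `AddedVocabulary` with the new `from_indexed_added_tokens` constructor, and install it with `with_added_vocabulary`. This is a minimal sketch, not part of the PR: it assumes `AddedVocabulary` is re-exported at the crate root (as the `use tk::{AddedToken, AddedVocabulary}` import above suggests) and uses an empty `WordLevel` model and made-up token IDs purely for illustration.

```rust
use std::collections::HashMap;

use tokenizers::models::wordlevel::WordLevel;
use tokenizers::{AddedToken, AddedVocabulary, Tokenizer};

fn main() {
    // A tokenizer with an empty WordLevel model, purely for illustration.
    let mut tokenizer = Tokenizer::new(WordLevel::default());

    // Added tokens that already carry their IDs, e.g. taken from another tokenizer.
    let added_tokens: HashMap<u32, AddedToken> = HashMap::from([
        (60_000, AddedToken::from("<ctrl>", true)),  // special token
        (60_001, AddedToken::from("<user>", false)), // regular added token
    ]);

    // Build the vocabulary against the tokenizer's model/normalizer, then install it.
    let vocab = AddedVocabulary::from_indexed_added_tokens(
        added_tokens,
        false, // encode_special_tokens
        tokenizer.get_model(),
        tokenizer.get_normalizer(),
    );
    tokenizer.with_added_vocabulary(vocab);

    // The pre-assigned IDs are preserved.
    assert_eq!(tokenizer.token_to_id("<ctrl>"), Some(60_000));
    assert_eq!(tokenizer.token_to_id("<user>"), Some(60_001));
}
```

This is the same flow that `set_added_tokens_decoder` performs on behalf of Python callers.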
8 changes: 8 additions & 0 deletions tokenizers/src/decoders/sequence.rs
@@ -13,6 +13,14 @@ impl Sequence {
pub fn new(decoders: Vec<DecoderWrapper>) -> Self {
Self { decoders }
}

pub fn get_decoders(&self) -> &[DecoderWrapper] {
&self.decoders
}

pub fn get_decoders_mut(&mut self) -> &mut [DecoderWrapper] {
&mut self.decoders
}
}

impl Decoder for Sequence {
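The new accessors let callers inspect or tweak the decoders nested inside a `Sequence` after construction. A hedged sketch of how they might be used; it assumes the `WordPiece` decoder keeps its `Default` impl and public `cleanup` field, as in the current crate:

```rust
use tokenizers::decoders::sequence::Sequence;
use tokenizers::decoders::wordpiece::WordPiece;
use tokenizers::decoders::DecoderWrapper;

fn main() {
    // A Sequence decoder wrapping a single WordPiece decoder.
    let mut seq = Sequence::new(vec![DecoderWrapper::WordPiece(WordPiece::default())]);

    // Read-only access to the nested decoders.
    assert_eq!(seq.get_decoders().len(), 1);

    // Mutable access: e.g. turn off WordPiece's post-decoding cleanup in place.
    for decoder in seq.get_decoders_mut() {
        if let DecoderWrapper::WordPiece(wordpiece) = decoder {
            wordpiece.cleanup = false;
        }
    }
}
```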
69 changes: 68 additions & 1 deletion tokenizers/src/tokenizer/added_vocabulary.rs
@@ -139,7 +139,7 @@ fn space_rightmost_at_start(sentence: &str) -> usize {
/// exist as required.
///
#[derive(Clone, Debug)]
pub(super) struct AddedVocabulary {
pub struct AddedVocabulary {
/// Contains the mapping from String (token content) to ID. This map contains both special
/// tokens and classic added tokens that were added to this vocabulary.
added_tokens_map: HashMap<String, u32>,
@@ -186,12 +186,73 @@ impl AddedVocabulary {
encode_special_tokens: false,
}
}

/// Creates a new [AddedVocabulary] from a collection of [AddedToken]s that already have assigned IDs.
/// This constructor is useful for rebuilding an [AddedVocabulary] from a pre-existing one
/// (e.g., an [AddedVocabulary] that was deserialized).
pub fn from_indexed_added_tokens<N: Normalizer>(
tokens: HashMap<u32, AddedToken>,
encode_special_tokens: bool,
model: &impl Model,
normalizer: Option<&N>,
) -> Self {
let mut vocabulary = AddedVocabulary::new();
vocabulary.encode_special_tokens = encode_special_tokens;

// Handle special tokens (if any).
for token in tokens.values() {
if token.special
&& !token.content.is_empty()
&& !vocabulary.special_tokens_set.contains(&token.content)
{
vocabulary.special_tokens.push(token.to_owned());
vocabulary.special_tokens_set.insert(token.content.clone());
}
}

for (token_id, token) in tokens {
if token.content.is_empty()
|| vocabulary
.added_tokens_map_r
.values()
.any(|val| *val == token)
{
continue;
}

vocabulary
.added_tokens_map
.entry(token.content.clone())
.and_modify(|old_id| *old_id = token_id)
.or_insert_with(|| token_id);

vocabulary
.added_tokens_map_r
.entry(token_id)
.and_modify(|t| *t = token.clone())
.or_insert_with(|| token.clone());

if !vocabulary.special_tokens_set.contains(&token.content) {
vocabulary.added_tokens.push(token.clone());
}
}

vocabulary.refresh_added_tokens(model, normalizer);

vocabulary
}

/// Size of the additional vocabulary
#[allow(dead_code)] // Suppress the "method is never used" warning
pub fn len(&self) -> usize {
self.added_tokens_map.len()
}

/// Whether or not this vocabulary is empty
pub fn is_empty(&self) -> bool {
self.added_tokens_map.is_empty()
}

/// Get the additional vocabulary
pub fn get_vocab(&self) -> &HashMap<String, u32> {
&self.added_tokens_map
@@ -487,6 +548,12 @@ impl AddedVocabulary {
}
}

impl Default for AddedVocabulary {
fn default() -> Self {
Self::new()
}
}

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct AddedTokenWithId {
/// The id assigned to this token
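As the loops above show, tokens with empty content are dropped, a token equal to one already registered is not inserted twice, and special tokens are additionally recorded in the special-token set. A small sketch of that behavior, again assuming the crate-root re-export of `AddedVocabulary`, with a default `WordLevel` model, no normalizer, and invented token IDs:

```rust
use std::collections::HashMap;

use tokenizers::models::wordlevel::WordLevel;
use tokenizers::normalizers::NormalizerWrapper;
use tokenizers::{AddedToken, AddedVocabulary};

fn main() {
    let model = WordLevel::default();
    let tokens: HashMap<u32, AddedToken> = HashMap::from([
        (60_000, AddedToken::from("<ctrl>", true)), // kept, and registered as special
        (60_001, AddedToken::from("", false)),      // empty content: skipped
        (60_002, AddedToken::from("<ctrl>", true)), // duplicate of an existing token: skipped
    ]);

    let vocab = AddedVocabulary::from_indexed_added_tokens(
        tokens,
        false, // encode_special_tokens
        &model,
        None::<&NormalizerWrapper>,
    );

    // Only "<ctrl>" ends up in the added-tokens map.
    assert_eq!(vocab.len(), 1);
}
```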
17 changes: 17 additions & 0 deletions tokenizers/src/tokenizer/mod.rs
@@ -384,6 +384,12 @@ where
self
}

/// Set the added vocabulary.
pub fn with_added_vocabulary(mut self, added_vocabulary: AddedVocabulary) -> Self {
self.added_vocabulary = added_vocabulary;
self
}

/// Set the truncation parameters.
#[must_use]
pub fn with_truncation(mut self, trunc: Option<TruncationParams>) -> Self {
@@ -598,6 +604,17 @@ where
&self.model
}

/// Set the added vocabulary.
pub fn with_added_vocabulary(&mut self, added_vocabulary: AddedVocabulary) -> &mut Self {
self.added_vocabulary = added_vocabulary;
self
}

/// Get the added vocabulary
pub fn get_added_vocabulary(&self) -> &AddedVocabulary {
&self.added_vocabulary
}

/// Set the truncation parameters
///
/// Fails if `stride` is too high relative to `max_length` and `post_processor.added_tokens()`
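Combined with the `Clone` derive on `AddedVocabulary`, the new getter and setter make the PR title's use case direct: the added vocabulary of one tokenizer can be copied onto another. A sketch under the assumption that `source` and `target` are two already-configured `Tokenizer` instances (both names are hypothetical):

```rust
use tokenizers::Tokenizer;

/// Copy the added-tokens vocabulary from one tokenizer onto another.
/// (`source` and `target` are assumed to be already-configured tokenizers.)
fn copy_added_vocabulary(source: &Tokenizer, target: &mut Tokenizer) {
    let vocab = source.get_added_vocabulary().clone();
    target.with_added_vocabulary(vocab);
}
```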