diff --git a/bindings/python/py_src/tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/__init__.pyi
index 7c21c5b56..1e05891b2 100644
--- a/bindings/python/py_src/tokenizers/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/__init__.pyi
@@ -1046,6 +1046,17 @@ class Tokenizer:
                 Whether the JSON file should be pretty formatted.
         """
         pass
+    def set_added_tokens_decoder(self, added_tokens_decoder, encode_special_tokens=False):
+        """
+        Sets the underlying added tokens vocabulary
+
+        Args:
+            added_tokens_decoder (:obj:`Dict[int, AddedToken]`):
+                Map from added token ID to :obj:`AddedToken`.
+            encode_special_tokens (:obj:`bool`, defaults to :obj:`False`):
+                Whether or not special tokens should be split when encoding. This is equivalent to ignoring them.
+        """
+        pass
     def to_str(self, pretty=False):
         """
         Gets a serialized string representing this :class:`~tokenizers.Tokenizer`.
diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 4e792ef54..18b2f5016 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -25,6 +25,7 @@ use super::trainers::PyTrainer;
 use crate::processors::PyPostProcessor;
 use crate::utils::{MaybeSizedIterator, PyBufferedIterator};
 use std::collections::BTreeMap;
+use tk::{AddedToken, AddedVocabulary};
 
 /// Represents a token that can be be added to a :class:`~tokenizers.Tokenizer`.
 /// It can have special options that defines the way it should behave.
@@ -662,6 +663,42 @@ impl PyTokenizer {
         self.tokenizer.get_vocab(with_added_tokens)
     }
 
+    /// Sets the underlying added tokens vocabulary
+    ///
+    /// Args:
+    ///     added_tokens_decoder (:obj:`Dict[int, AddedToken]`):
+    ///         Map from added token ID to :obj:`AddedToken`.
+    ///     encode_special_tokens (:obj:`bool`, defaults to :obj:`False`):
+    ///         Whether or not special tokens should be split when encoding. This is equivalent to ignoring them.
+    #[pyo3(signature = (added_tokens_decoder, encode_special_tokens = false))]
+    #[pyo3(text_signature = "(self, added_tokens_decoder, encode_special_tokens=False)")]
+    fn set_added_tokens_decoder(
+        &mut self,
+        added_tokens_decoder: &PyDict,
+        encode_special_tokens: bool,
+    ) -> PyResult<()> {
+        added_tokens_decoder
+            .iter()
+            .map(|(key, value)| {
+                key.extract::<u32>().and_then(|key| {
+                    value
+                        .extract::<PyRef<PyAddedToken>>()
+                        .map(|value| (key, value.get_token()))
+                })
+            })
+            .collect::<Result<HashMap<u32, AddedToken>, PyErr>>()
+            .map(|added_tokens| {
+                self.tokenizer
+                    .with_added_vocabulary(AddedVocabulary::from_indexed_added_tokens(
+                        added_tokens,
+                        encode_special_tokens,
+                        self.tokenizer.get_model(),
+                        self.tokenizer.get_normalizer(),
+                    ))
+            })?;
+        Ok(())
+    }
+
     /// Get the underlying vocabulary
     ///
     /// Returns:
diff --git a/tokenizers/src/decoders/sequence.rs b/tokenizers/src/decoders/sequence.rs
index 484df6c95..73169b695 100644
--- a/tokenizers/src/decoders/sequence.rs
+++ b/tokenizers/src/decoders/sequence.rs
@@ -13,6 +13,14 @@ impl Sequence {
     pub fn new(decoders: Vec<DecoderWrapper>) -> Self {
         Self { decoders }
     }
+
+    pub fn get_decoders(&self) -> &[DecoderWrapper] {
+        &self.decoders
+    }
+
+    pub fn get_decoders_mut(&mut self) -> &mut [DecoderWrapper] {
+        &mut self.decoders
+    }
 }
 
 impl Decoder for Sequence {
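The two new `Sequence` accessors expose the decoder chain for read-only inspection and in-place mutation. A minimal usage sketch, not part of this diff; it assumes the crate's existing `DecoderWrapper` enum and the public `cleanup` field on the `WordPiece` decoder:

```rust
use tokenizers::decoders::sequence::Sequence;
use tokenizers::decoders::wordpiece::WordPiece;
use tokenizers::decoders::DecoderWrapper;

fn main() {
    // A Sequence decoder wrapping a single WordPiece sub-decoder.
    let mut seq = Sequence::new(vec![DecoderWrapper::WordPiece(WordPiece::default())]);

    // Read-only view of the chain via the new accessor.
    assert_eq!(seq.get_decoders().len(), 1);

    // Mutable access: tweak a sub-decoder without rebuilding the Sequence.
    if let Some(DecoderWrapper::WordPiece(wp)) = seq.get_decoders_mut().first_mut() {
        wp.cleanup = false;
    }
}
```

Returning slices rather than `&Vec<DecoderWrapper>` keeps the field private while still allowing iteration and element-level mutation.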
diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 487fb4479..436f5376e 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -139,7 +139,7 @@ fn space_rightmost_at_start(sentence: &str) -> usize {
 /// exist as required.
 ///
 #[derive(Clone, Debug)]
-pub(super) struct AddedVocabulary {
+pub struct AddedVocabulary {
     /// Contains the mapping from String (token content) to ID. This map contains both special
     /// tokens and classic added tokens that were added to the this vocabulary.
     added_tokens_map: HashMap<String, u32>,
@@ -186,12 +186,73 @@ impl AddedVocabulary {
             encode_special_tokens: false,
         }
     }
+
+    /// Creates a new [AddedVocabulary] from a collection of [AddedToken]s that already have assigned IDs.
+    /// This constructor is useful for constructing an [AddedVocabulary] from a pre-existing [AddedVocabulary]
+    /// (e.g., from a serialized [AddedVocabulary]).
+    pub fn from_indexed_added_tokens<N: Normalizer>(
+        tokens: HashMap<u32, AddedToken>,
+        encode_special_tokens: bool,
+        model: &impl Model,
+        normalizer: Option<&N>,
+    ) -> Self {
+        let mut vocabulary = AddedVocabulary::new();
+        vocabulary.encode_special_tokens = encode_special_tokens;
+
+        // Handle special tokens (if any).
+        for token in tokens.values() {
+            if token.special
+                && !token.content.is_empty()
+                && !vocabulary.special_tokens_set.contains(&token.content)
+            {
+                vocabulary.special_tokens.push(token.to_owned());
+                vocabulary.special_tokens_set.insert(token.content.clone());
+            }
+        }
+
+        for (token_id, token) in tokens {
+            if token.content.is_empty()
+                || vocabulary
+                    .added_tokens_map_r
+                    .values()
+                    .any(|val| *val == token)
+            {
+                continue;
+            }
+
+            vocabulary
+                .added_tokens_map
+                .entry(token.content.clone())
+                .and_modify(|old_id| *old_id = token_id)
+                .or_insert_with(|| token_id);
+
+            vocabulary
+                .added_tokens_map_r
+                .entry(token_id)
+                .and_modify(|t| *t = token.clone())
+                .or_insert_with(|| token.clone());
+
+            if !vocabulary.special_tokens_set.contains(&token.content) {
+                vocabulary.added_tokens.push(token.clone());
+            }
+        }
+
+        vocabulary.refresh_added_tokens(model, normalizer);
+
+        vocabulary
+    }
+
     /// Size of the additional vocabulary
     #[allow(dead_code)] // Suppress the "method is never used" warning
     pub fn len(&self) -> usize {
         self.added_tokens_map.len()
     }
+
+    /// Whether or not this vocabulary is empty
+    pub fn is_empty(&self) -> bool {
+        self.added_tokens_map.is_empty()
+    }
+
     /// Get the additional vocabulary
     pub fn get_vocab(&self) -> &HashMap<String, u32> {
         &self.added_tokens_map
@@ -487,6 +548,12 @@ impl AddedVocabulary {
     }
 }
 
+impl Default for AddedVocabulary {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 pub(super) struct AddedTokenWithId {
     /// The id assigned to this token
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index ae6a64362..504026235 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -384,6 +384,12 @@ where
         self
     }
 
+    /// Set the added vocabulary.
+    pub fn with_added_vocabulary(mut self, added_vocabulary: AddedVocabulary) -> Self {
+        self.added_vocabulary = added_vocabulary;
+        self
+    }
+
     /// Set the trunaction parameters.
     #[must_use]
     pub fn with_truncation(mut self, trunc: Option<TruncationParams>) -> Self {
@@ -598,6 +604,17 @@ where
         &self.model
     }
 
+    /// Set the added vocabulary.
+    pub fn with_added_vocabulary(&mut self, added_vocabulary: AddedVocabulary) -> &mut Self {
+        self.added_vocabulary = added_vocabulary;
+        self
+    }
+
+    /// Get the added vocabulary
+    pub fn get_added_vocabulary(&self) -> &AddedVocabulary {
+        &self.added_vocabulary
+    }
+
     /// Set the truncation parameters
     ///
     /// Fails if `stride` is too high relative to `max_length` and `post_processor.added_tokens()`
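Taken together, these pieces let an added vocabulary with pre-assigned IDs be rebuilt and swapped into an existing tokenizer. A hedged sketch of the Rust-side flow, not part of this diff; `BPE::default()` is only a stand-in model, `AddedToken::from` is the crate's existing constructor, and the token content and ID are made up:

```rust
use std::collections::HashMap;

use tokenizers::models::bpe::BPE;
use tokenizers::{AddedToken, AddedVocabulary, Tokenizer};

fn main() {
    let mut tokenizer = Tokenizer::new(BPE::default());

    // A map of pre-assigned IDs, e.g. restored from a serialized
    // `added_tokens_decoder`; the ID and content here are illustrative.
    let mut tokens: HashMap<u32, AddedToken> = HashMap::new();
    tokens.insert(50_000, AddedToken::from("<extra_0>", true));

    // Rebuild the added vocabulary and install it on the tokenizer.
    let vocabulary = AddedVocabulary::from_indexed_added_tokens(
        tokens,
        false, // encode_special_tokens
        tokenizer.get_model(),
        tokenizer.get_normalizer(),
    );
    tokenizer.with_added_vocabulary(vocabulary);
}
```

On the Python side the same flow is exposed as `Tokenizer.set_added_tokens_decoder`, which takes the `Dict[int, AddedToken]` map directly and forwards it through the binding shown above.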