From 41b3f5ae422ff6e686ee57ef9f8b46976424ebbb Mon Sep 17 00:00:00 2001
From: Anthony Platanios
Date: Mon, 22 Jan 2024 13:24:07 -0800
Subject: [PATCH 1/6] Fixes.

---
 tokenizers/src/decoders/sequence.rs |  8 ++++++++
 tokenizers/src/tokenizer/mod.rs     | 17 +++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/tokenizers/src/decoders/sequence.rs b/tokenizers/src/decoders/sequence.rs
index 484df6c95..73169b695 100644
--- a/tokenizers/src/decoders/sequence.rs
+++ b/tokenizers/src/decoders/sequence.rs
@@ -13,6 +13,14 @@ impl Sequence {
     pub fn new(decoders: Vec<DecoderWrapper>) -> Self {
         Self { decoders }
     }
+
+    pub fn get_decoders(&self) -> &[DecoderWrapper] {
+        &self.decoders
+    }
+
+    pub fn get_decoders_mut(&mut self) -> &mut [DecoderWrapper] {
+        &mut self.decoders
+    }
 }
 
 impl Decoder for Sequence {
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index ae6a64362..01a598423 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -384,6 +384,12 @@ where
         self
     }
 
+    /// Set the added vocabulary.
+    pub fn with_added_vocabulary(mut self, added_vocabulary: AddedVocabulary) -> Self {
+        self.added_vocabulary = added_vocabulary;
+        self
+    }
+
     /// Set the trunaction parameters.
     #[must_use]
     pub fn with_truncation(mut self, trunc: Option<TruncationParams>) -> Self {
@@ -598,6 +604,17 @@ where
         &self.model
     }
 
+    /// Set the added vocabulary.
+    pub fn with_added_vocabulary(&mut self, added_vocabulary: AddedVocabulary) -> &mut Self {
+        self.added_vocabulary = added_vocabulary.into();
+        self
+    }
+
+    /// Get the added vocabulary
+    pub fn get_added_vocabulary(&self) -> &AddedVocabulary {
+        &self.added_vocabulary
+    }
+
     /// Set the truncation parameters
     ///
     /// Fails if `stride` is too high relative to `max_length` and `post_processor.added_tokens()`

From 25a7f916acc42b5d00d20fab0eac73aa3ef6e397 Mon Sep 17 00:00:00 2001
From: Anthony Platanios
Date: Mon, 22 Jan 2024 13:32:06 -0800
Subject: [PATCH 2/6] Fixes.

---
 tokenizers/src/tokenizer/added_vocabulary.rs | 13 ++++++++++++-
 tokenizers/src/tokenizer/mod.rs              |  2 +-
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 487fb4479..b7521fde4 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -139,7 +139,7 @@ fn space_rightmost_at_start(sentence: &str) -> usize {
 /// exist as required.
 ///
 #[derive(Clone, Debug)]
-pub(super) struct AddedVocabulary {
+pub struct AddedVocabulary {
     /// Contains the mapping from String (token content) to ID. This map contains both special
     /// tokens and classic added tokens that were added to the this vocabulary.
     added_tokens_map: HashMap<String, u32>,
@@ -192,6 +192,11 @@ impl AddedVocabulary {
         self.added_tokens_map.len()
     }
 
+    /// Whether or not this vocabulary is empty
+    pub fn is_empty(&self) -> bool {
+        self.added_tokens_map.is_empty()
+    }
+
     /// Get the additional vocabulary
     pub fn get_vocab(&self) -> &HashMap<String, u32> {
         &self.added_tokens_map
@@ -487,6 +492,12 @@ impl AddedVocabulary {
     }
 }
 
+impl Default for AddedVocabulary {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 pub(super) struct AddedTokenWithId {
     /// The id assigned to this token
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 01a598423..504026235 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -606,7 +606,7 @@ where
 
     /// Set the added vocabulary.
     pub fn with_added_vocabulary(&mut self, added_vocabulary: AddedVocabulary) -> &mut Self {
-        self.added_vocabulary = added_vocabulary.into();
+        self.added_vocabulary = added_vocabulary;
         self
     }
 
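Taken together, the two patches above expose the added vocabulary on the Rust side. The following is a minimal usage sketch, not part of the patch series, assuming a serialized tokenizer at the placeholder path "tokenizer.json":

use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    // Placeholder path; any serialized tokenizer file works here.
    let mut tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // Inspect the added vocabulary through the new getter.
    let added = tokenizer.get_added_vocabulary().clone();
    println!("{} added tokens (empty: {})", added.len(), added.is_empty());

    // Install an AddedVocabulary wholesale through the new setter.
    tokenizer.with_added_vocabulary(added);
    Ok(())
}
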
From 625b08048559d23427e11460ba6387183677a3a7 Mon Sep 17 00:00:00 2001
From: Anthony Platanios
Date: Mon, 22 Jan 2024 21:54:48 -0800
Subject: [PATCH 3/6] Fixes.

---
 bindings/python/src/tokenizer.rs             | 37 +++++++++++++++
 tokenizers/src/tokenizer/added_vocabulary.rs | 50 ++++++++++++++++++++
 2 files changed, 87 insertions(+)

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 4e792ef54..529ddab30 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -25,6 +25,7 @@ use super::trainers::PyTrainer;
 use crate::processors::PyPostProcessor;
 use crate::utils::{MaybeSizedIterator, PyBufferedIterator};
 use std::collections::BTreeMap;
+use tk::{AddedToken, AddedVocabulary};
 
 /// Represents a token that can be be added to a :class:`~tokenizers.Tokenizer`.
 /// It can have special options that defines the way it should behave.
@@ -662,6 +663,42 @@ impl PyTokenizer {
         self.tokenizer.get_vocab(with_added_tokens)
     }
 
+    /// Sets the underlying added tokens vocabulary
+    ///
+    /// Args:
+    ///     added_tokens_decoder (:obj:`Dict[int, AddedToken]`):
+    ///         Map from added token ID to :obj:`AddedToken`.
+    ///     encode_special_tokens (:obj:`bool`, defaults to :onj:`False`):
+    ///         Whether or not special tokens should be split when encoding. This is equivalent to ignoring them.
+    #[pyo3(signature = (added_tokens_decoder, encode_special_tokens = false))]
+    #[pyo3(text_signature = "(self, added_tokens_decoder, encode_special_tokens=False)")]
+    fn set_added_tokens_decoder(
+        &mut self,
+        added_tokens_decoder: &PyDict,
+        encode_special_tokens: bool,
+    ) -> PyResult<()> {
+        added_tokens_decoder
+            .iter()
+            .map(|(key, value)| {
+                key.extract::<u32>().and_then(|key| {
+                    value
+                        .extract::<PyRef<PyAddedToken>>()
+                        .map(|value| (key, value.get_token()))
+                })
+            })
+            .collect::<Result<HashMap<u32, AddedToken>, PyErr>>()
+            .map(|added_tokens| {
+                self.tokenizer
+                    .with_added_vocabulary(AddedVocabulary::from_indexed_added_tokens(
+                        added_tokens,
+                        encode_special_tokens,
+                        self.tokenizer.get_model(),
+                        self.tokenizer.get_normalizer(),
+                    ))
+            })?;
+        Ok(())
+    }
+
     /// Get the underlying vocabulary
     ///
     /// Returns:
diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index b7521fde4..9b6afa913 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -186,6 +186,56 @@ impl AddedVocabulary {
             encode_special_tokens: false,
         }
     }
+
+    /// Creates a new [AddedVocabulary] from a collection of [AddedToken]s that already have assigned IDs.
+    /// This constructor is useful for constructing an [AddedVocabulary] from a pre-existing [AddedVocabulary]
+    /// (e.g., from a serialized [AddedVocabulary]).
+    pub fn from_indexed_added_tokens<N: Normalizer>(
+        tokens: HashMap<u32, AddedToken>,
+        encode_special_tokens: bool,
+        model: &impl Model,
+        normalizer: Option<&N>,
+    ) -> Self {
+        let mut vocabulary = AddedVocabulary::new();
+        vocabulary.encode_special_tokens = encode_special_tokens;
+
+        // Handle special tokens (if any).
+        for token in tokens.values() {
+            if token.special
+                && !token.content.is_empty()
+                && !vocabulary.special_tokens_set.contains(&token.content)
+            {
+                vocabulary.special_tokens.push(token.to_owned());
+                vocabulary.special_tokens_set.insert(token.content.clone());
+            }
+        }
+
+        for (token_id, token) in tokens {
+            if token.content.is_empty() || vocabulary.added_tokens_map_r.values().any(|val| *val == token)
+            {
+                continue;
+            }
+
+            vocabulary.added_tokens_map
+                .entry(token.content.clone())
+                .and_modify(|old_id| *old_id = token_id)
+                .or_insert_with(|| token_id);
+
+            vocabulary.added_tokens_map_r
+                .entry(token_id)
+                .and_modify(|t| *t = token.clone())
+                .or_insert_with(|| token.clone());
+
+            if !vocabulary.special_tokens_set.contains(&token.content) {
+                vocabulary.added_tokens.push(token.clone());
+            }
+        }
+
+        vocabulary.refresh_added_tokens(model, normalizer);
+
+        vocabulary
+    }
+
     /// Size of the additional vocabulary
     #[allow(dead_code)] // Suppress the "method is never used" warning
     pub fn len(&self) -> usize {
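The constructor introduced above is what the Python binding calls. A hypothetical Rust-side sketch follows, with made-up token IDs and a placeholder file path; the crate-root import mirrors the binding's `use tk::{AddedToken, AddedVocabulary}`:

use std::collections::HashMap;
use tokenizers::{AddedToken, AddedVocabulary, Tokenizer};

fn main() -> tokenizers::Result<()> {
    let mut tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // The IDs must already be assigned (e.g., recovered from a serialized
    // tokenizer); 50000/50001 are illustrative only.
    let tokens: HashMap<u32, AddedToken> = HashMap::from([
        (50_000, AddedToken::from("<pad>", true)),
        (50_001, AddedToken::from("<sep>", true)),
    ]);

    let vocabulary = AddedVocabulary::from_indexed_added_tokens(
        tokens,
        false, // encode_special_tokens
        tokenizer.get_model(),
        tokenizer.get_normalizer(),
    );
    tokenizer.with_added_vocabulary(vocabulary);
    Ok(())
}
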
From 97f845b4e15fcf1a32bfbba33c992949690c900e Mon Sep 17 00:00:00 2001
From: Anthony Platanios
Date: Tue, 23 Jan 2024 04:01:27 -0800
Subject: [PATCH 4/6] Format fix.

---
 tokenizers/src/tokenizer/added_vocabulary.rs | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 9b6afa913..436f5376e 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -198,7 +198,7 @@ impl AddedVocabulary {
     ) -> Self {
         let mut vocabulary = AddedVocabulary::new();
         vocabulary.encode_special_tokens = encode_special_tokens;
-        
+
         // Handle special tokens (if any).
         for token in tokens.values() {
             if token.special
@@ -209,19 +209,25 @@ impl AddedVocabulary {
                 vocabulary.special_tokens_set.insert(token.content.clone());
             }
         }
-        
+
         for (token_id, token) in tokens {
-            if token.content.is_empty() || vocabulary.added_tokens_map_r.values().any(|val| *val == token)
+            if token.content.is_empty()
+                || vocabulary
+                    .added_tokens_map_r
+                    .values()
+                    .any(|val| *val == token)
             {
                 continue;
             }
 
-            vocabulary.added_tokens_map
+            vocabulary
+                .added_tokens_map
                 .entry(token.content.clone())
                 .and_modify(|old_id| *old_id = token_id)
                 .or_insert_with(|| token_id);
 
-            vocabulary.added_tokens_map_r
+            vocabulary
+                .added_tokens_map_r
                 .entry(token_id)
                 .and_modify(|t| *t = token.clone())
                 .or_insert_with(|| token.clone());

From 164f8fe46148adb0f45c80a7bff52eecb39a8385 Mon Sep 17 00:00:00 2001
From: Anthony Platanios
Date: Tue, 23 Jan 2024 04:08:24 -0800
Subject: [PATCH 5/6] .

---
 bindings/python/src/tokenizer.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 529ddab30..18b2f5016 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -668,7 +668,7 @@ impl PyTokenizer {
     /// Args:
     ///     added_tokens_decoder (:obj:`Dict[int, AddedToken]`):
    ///         Map from added token ID to :obj:`AddedToken`.
-    ///     encode_special_tokens (:obj:`bool`, defaults to :onj:`False`):
+    ///     encode_special_tokens (:obj:`bool`, defaults to :obj:`False`):
     ///         Whether or not special tokens should be split when encoding. This is equivalent to ignoring them.
     #[pyo3(signature = (added_tokens_decoder, encode_special_tokens = false))]
     #[pyo3(text_signature = "(self, added_tokens_decoder, encode_special_tokens=False)")]
     fn set_added_tokens_decoder(

From 082f800904bf4af8a47b6eaf93aeb4fa625071c0 Mon Sep 17 00:00:00 2001
From: Anthony Platanios
Date: Tue, 23 Jan 2024 04:23:03 -0800
Subject: [PATCH 6/6] .

---
 bindings/python/py_src/tokenizers/__init__.pyi | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/bindings/python/py_src/tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/__init__.pyi
index 7c21c5b56..1e05891b2 100644
--- a/bindings/python/py_src/tokenizers/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/__init__.pyi
@@ -1046,6 +1046,17 @@ class Tokenizer:
                 Whether the JSON file should be pretty formatted.
         """
         pass
+    def set_added_tokens_decoder(self, added_tokens_decoder, encode_special_tokens=False):
+        """
+        Sets the underlying added tokens vocabulary
+
+        Args:
+            added_tokens_decoder (:obj:`Dict[int, AddedToken]`):
+                Map from added token ID to :obj:`AddedToken`.
+            encode_special_tokens (:obj:`bool`, defaults to :obj:`False`):
+                Whether or not special tokens should be split when encoding. This is equivalent to ignoring them.
+        """
+        pass
     def to_str(self, pretty=False):
         """
         Gets a serialized string representing this :class:`~tokenizers.Tokenizer`.
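
To make the new binding concrete, here is a hypothetical Python usage sketch, not part of the patches; the file name and the IDs 50000/50001 are placeholders chosen for illustration:

from tokenizers import AddedToken, Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")  # placeholder path
tokenizer.set_added_tokens_decoder(
    {
        50000: AddedToken("<pad>", special=True),
        50001: AddedToken("<sep>", special=True),
    },
    encode_special_tokens=False,
)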