fix pylist (#1673)
* fix pylist

* add comment about why we use PySequence

* style

* fix encode batch fast as well

* Update bindings/python/src/tokenizer.rs

Co-authored-by: Nicolas Patry <[email protected]>

* fix with capacity

* stub :)

---------

Co-authored-by: Nicolas Patry <[email protected]>
ArthurZucker and Narsil authored Nov 5, 2024
1 parent 0f3a3f9 commit 5e223ce
Showing 2 changed files with 30 additions and 28 deletions.
4 changes: 3 additions & 1 deletion bindings/python/py_src/tokenizers/__init__.pyi
@@ -859,7 +859,9 @@ class Tokenizer:
     def encode_batch(self, input, is_pretokenized=False, add_special_tokens=True):
         """
         Encode the given batch of inputs. This method accepts both raw text sequences
-        as well as already pre-tokenized sequences.
+        as well as already pre-tokenized sequences. We use `PySequence` because it
+        gives zero-cost type checking (according to PyO3): we don't have to convert
+        the input just to check its type.
 
         Example:
             Here are some examples of the inputs that are accepted::
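For context, a minimal sketch of the inputs this docstring describes, written against the public `tokenizers` Python API; the tokenizer file and the example texts are placeholders, not part of this commit:

    # A minimal sketch; "tokenizer.json" stands in for any saved tokenizer file.
    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_file("tokenizer.json")

    # Raw text sequences:
    encodings = tokenizer.encode_batch(["Hello there!", "General Kenobi."])

    # Already pre-tokenized sequences:
    encodings = tokenizer.encode_batch(
        [["Hello", "there", "!"], ["General", "Kenobi", "."]],
        is_pretokenized=True,
    )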
54 changes: 27 additions & 27 deletions bindings/python/src/tokenizer.rs
@@ -995,7 +995,9 @@ impl PyTokenizer {
     }
 
     /// Encode the given batch of inputs. This method accepts both raw text sequences
-    /// as well as already pre-tokenized sequences.
+    /// as well as already pre-tokenized sequences. We use `PySequence` because it
+    /// gives zero-cost type checking (according to PyO3): we don't have to convert
+    /// the input just to check its type.
     ///
     /// Example:
     ///     Here are some examples of the inputs that are accepted::
@@ -1030,25 +1032,24 @@
     fn encode_batch(
         &self,
         py: Python<'_>,
-        input: Bound<'_, PyList>,
+        input: Bound<'_, PySequence>,
         is_pretokenized: bool,
         add_special_tokens: bool,
     ) -> PyResult<Vec<PyEncoding>> {
-        let input: Vec<tk::EncodeInput> = input
-            .into_iter()
-            .map(|o| {
-                let input: tk::EncodeInput = if is_pretokenized {
-                    o.extract::<PreTokenizedEncodeInput>()?.into()
-                } else {
-                    o.extract::<TextEncodeInput>()?.into()
-                };
-                Ok(input)
-            })
-            .collect::<PyResult<Vec<tk::EncodeInput>>>()?;
+        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len()?);
+        for i in 0..input.len()? {
+            let item = input.get_item(i)?;
+            let item: tk::EncodeInput = if is_pretokenized {
+                item.extract::<PreTokenizedEncodeInput>()?.into()
+            } else {
+                item.extract::<TextEncodeInput>()?.into()
+            };
+            items.push(item);
+        }
         py.allow_threads(|| {
             ToPyResult(
                 self.tokenizer
-                    .encode_batch_char_offsets(input, add_special_tokens)
+                    .encode_batch_char_offsets(items, add_special_tokens)
                     .map(|encodings| encodings.into_iter().map(|e| e.into()).collect()),
             )
             .into()
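In user-facing terms, widening the argument from `PyList` to `PySequence` means any object implementing Python's sequence protocol now passes the zero-cost type check, not only `list`. A hedged sketch of the difference, reusing the `tokenizer` from the sketch above (file names are placeholders; exact error messages vary by version):

    # Tuples are sequences, so they now reach the extraction loop directly;
    # with the previous PyList signature this call raised a TypeError:
    encodings = tokenizer.encode_batch(("Hello there!", "General Kenobi."))

    # Generators are not sequences (no len() or indexing), so they still
    # have to be materialized first:
    lines = (line.strip() for line in open("corpus.txt"))
    encodings = tokenizer.encode_batch(list(lines))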
@@ -1091,25 +1092,24 @@ impl PyTokenizer {
     fn encode_batch_fast(
         &self,
         py: Python<'_>,
-        input: Bound<'_, PyList>,
+        input: Bound<'_, PySequence>,
         is_pretokenized: bool,
         add_special_tokens: bool,
     ) -> PyResult<Vec<PyEncoding>> {
-        let input: Vec<tk::EncodeInput> = input
-            .into_iter()
-            .map(|o| {
-                let input: tk::EncodeInput = if is_pretokenized {
-                    o.extract::<PreTokenizedEncodeInput>()?.into()
-                } else {
-                    o.extract::<TextEncodeInput>()?.into()
-                };
-                Ok(input)
-            })
-            .collect::<PyResult<Vec<tk::EncodeInput>>>()?;
+        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len()?);
+        for i in 0..input.len()? {
+            let item = input.get_item(i)?;
+            let item: tk::EncodeInput = if is_pretokenized {
+                item.extract::<PreTokenizedEncodeInput>()?.into()
+            } else {
+                item.extract::<TextEncodeInput>()?.into()
+            };
+            items.push(item);
+        }
         py.allow_threads(|| {
             ToPyResult(
                 self.tokenizer
-                    .encode_batch_fast(input, add_special_tokens)
+                    .encode_batch_fast(items, add_special_tokens)
                     .map(|encodings| encodings.into_iter().map(|e| e.into()).collect()),
             )
             .into()
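`encode_batch_fast` gets the identical input handling, so the same sequence types should work there as well; a short sketch under the same assumptions as the examples above:

    # Same input contract as encode_batch, per the change above;
    # only the downstream encoding call differs.
    fast = tokenizer.encode_batch_fast(("Hello there!", "General Kenobi."))
    fast = tokenizer.encode_batch_fast(
        [["Hello", "there", "!"], ["General", "Kenobi", "."]],
        is_pretokenized=True,
    )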
