diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index d2e61ee1c..372bb7755 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1,11 +1,12 @@ use std::collections::{hash_map::DefaultHasher, HashMap}; use std::hash::{Hash, Hasher}; -use numpy::PyArray1; +use numpy::{npyffi, PyArray1}; use pyo3::class::basic::CompareOp; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; +use pyo3::AsPyPointer; use tk::models::bpe::BPE; use tk::tokenizer::{ Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl, @@ -258,8 +259,24 @@ impl<'s> From> for tk::InputSequence<'s> { struct PyArrayUnicode(Vec); impl FromPyObject<'_> for PyArrayUnicode { fn extract(ob: &PyAny) -> PyResult { - let array = ob.downcast::>()?; - let arr = array.as_array_ptr(); + if unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) } == 0 { + return Err(exceptions::PyTypeError::new_err("Expected an np.array")); + } + let arr = ob.as_ptr() as *mut npyffi::PyArrayObject; + if unsafe { (*arr).nd } != 1 { + return Err(exceptions::PyTypeError::new_err( + "Expected a 1 dimensional np.array", + )); + } + if unsafe { (*arr).flags } + & (npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS) + == 0 + { + return Err(exceptions::PyTypeError::new_err( + "Expected a contiguous np.array", + )); + } + let n_elem = unsafe { *(*arr).dimensions } as usize; let (type_num, elsize, alignment, data) = unsafe { let desc = (*arr).descr; ( @@ -269,10 +286,8 @@ impl FromPyObject<'_> for PyArrayUnicode { (*arr).data, ) }; - let n_elem = array.shape()[0]; - // type_num == 19 => Unicode - if type_num != 19 { + if type_num != npyffi::types::NPY_TYPES::NPY_UNICODE as i32 { return Err(exceptions::PyTypeError::new_err( "Expected a np.array[dtype='U']", )); @@ -289,8 +304,7 @@ impl FromPyObject<'_> for PyArrayUnicode { bytes.as_ptr() as *const _, elsize as isize / alignment as isize, ); - let gil = Python::acquire_gil(); - let py = gil.python(); + let py = ob.py(); let obj = PyObject::from_owned_ptr(py, unicode); let s = obj.cast_as::(py)?; Ok(s.to_string_lossy().trim_matches(char::from(0)).to_owned()) @@ -310,32 +324,18 @@ impl From for tk::InputSequence<'_> { struct PyArrayStr(Vec); impl FromPyObject<'_> for PyArrayStr { fn extract(ob: &PyAny) -> PyResult { - let array = ob.downcast::>()?; - let arr = array.as_array_ptr(); - let (type_num, data) = unsafe { ((*(*arr).descr).type_num, (*arr).data) }; - let n_elem = array.shape()[0]; - - if type_num != 17 { - return Err(exceptions::PyTypeError::new_err( - "Expected a np.array[dtype='O']", - )); - } - - unsafe { - let objects = std::slice::from_raw_parts(data as *const PyObject, n_elem); - - let seq = objects - .iter() - .map(|obj| { - let gil = Python::acquire_gil(); - let py = gil.python(); - let s = obj.cast_as::(py)?; - Ok(s.to_string_lossy().into_owned()) - }) - .collect::>>()?; + let array = ob.downcast::>()?; + let seq = array + .readonly() + .as_array() + .iter() + .map(|obj| { + let s = obj.cast_as::(ob.py())?; + Ok(s.to_string_lossy().into_owned()) + }) + .collect::>>()?; - Ok(Self(seq)) - } + Ok(Self(seq)) } } impl From for tk::InputSequence<'_> {