diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index d2e61ee1c..da8c74a9c 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1,11 +1,12 @@ use std::collections::{hash_map::DefaultHasher, HashMap}; use std::hash::{Hash, Hasher}; -use numpy::PyArray1; +use numpy::npyffi; use pyo3::class::basic::CompareOp; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::*; +use pyo3::AsPyPointer; use tk::models::bpe::BPE; use tk::tokenizer::{ Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl, @@ -258,8 +259,24 @@ impl<'s> From> for tk::InputSequence<'s> { struct PyArrayUnicode(Vec); impl FromPyObject<'_> for PyArrayUnicode { fn extract(ob: &PyAny) -> PyResult { - let array = ob.downcast::>()?; - let arr = array.as_array_ptr(); + if unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) } == 0 { + return Err(exceptions::PyTypeError::new_err("Expected an np.array")); + } + let arr = ob.as_ptr() as *mut npyffi::PyArrayObject; + if unsafe { (*arr).nd } != 1 { + return Err(exceptions::PyTypeError::new_err( + "Expected a 1 dimensional np.array", + )); + } + if unsafe { (*arr).flags } + & (npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS) + == 0 + { + return Err(exceptions::PyTypeError::new_err( + "Expected a contiguous np.array", + )); + } + let n_elem = unsafe { *(*arr).dimensions } as usize; let (type_num, elsize, alignment, data) = unsafe { let desc = (*arr).descr; ( @@ -269,7 +286,6 @@ impl FromPyObject<'_> for PyArrayUnicode { (*arr).data, ) }; - let n_elem = array.shape()[0]; // type_num == 19 => Unicode if type_num != 19 { @@ -310,10 +326,27 @@ impl From for tk::InputSequence<'_> { struct PyArrayStr(Vec); impl FromPyObject<'_> for PyArrayStr { fn extract(ob: &PyAny) -> PyResult { - let array = ob.downcast::>()?; - let arr = array.as_array_ptr(); + if unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) } == 0 { + return Err(exceptions::PyTypeError::new_err("Expected an np.array")); + } + let arr = ob.as_ptr() as *mut npyffi::PyArrayObject; + + if unsafe { (*arr).nd } != 1 { + return Err(exceptions::PyTypeError::new_err( + "Expected a 1 dimensional np.array", + )); + } + if unsafe { (*arr).flags } + & (npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS) + == 0 + { + return Err(exceptions::PyTypeError::new_err( + "Expected a contiguous np.array", + )); + } + let n_elem = unsafe { *(*arr).dimensions } as usize; + let (type_num, data) = unsafe { ((*(*arr).descr).type_num, (*arr).data) }; - let n_elem = array.shape()[0]; if type_num != 17 { return Err(exceptions::PyTypeError::new_err(